Skip to content

Commit

Permalink
switching constitutive model interface to use grad_u in all kinematic…
Browse files Browse the repository at this point in the history
… variable interfaces.
  • Loading branch information
cmhamel committed Feb 28, 2025
1 parent e94baca commit 7d1f59a
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 309 deletions.
56 changes: 35 additions & 21 deletions pancax/constitutive_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,83 +7,97 @@


Scalar = float
State = Float[Array, "ns"]
Tensor = Float[Array, "3 3"]


class BaseConstitutiveModel(eqx.Module):
def cauchy_stress(self, F: Tensor) -> Tensor:
# F = grad_u + jnp.eye(3)
J = self.jacobian(F)
P = self.pk1_stress(F)
def cauchy_stress(self, grad_u: Tensor) -> Tensor:
F = grad_u + jnp.eye(3)
J = self.jacobian(grad_u)
P = self.pk1_stress(grad_u)
return (1. / J) * P @ F.T

def deformation_gradient(self, grad_u: Tensor) -> Tensor:
F = grad_u + jnp.eye(3)
return F

@abstractmethod
def energy(self, F: Tensor) -> Scalar:
def energy(self, grad_u: Tensor) -> Scalar:
"""
This method returns the algorithmic strain energy density.
"""
pass

def invariants(self, F: Tensor) -> Tuple[Scalar, Scalar, Scalar]:
I1 = self.I1(F)
I2 = self.I2(F)
I3 = self.jacobian(F)**2
def invariants(self, grad_u: Tensor) -> Tuple[Scalar, Scalar, Scalar]:
I1 = self.I1(grad_u)
I2 = self.I2(grad_u)
I3 = self.I3(grad_u)

Check warning on line 35 in pancax/constitutive_models/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/base.py#L33-L35

Added lines #L33 - L35 were not covered by tests
return jnp.array([I1, I2, I3])

def I1(self, F: Tensor) -> Scalar:
def I1(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first invariant
- **F**: the deformation gradient
- **grad_u**: the displacement gradient
$$
I_1 = tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)

Check warning on line 48 in pancax/constitutive_models/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/base.py#L48

Added line #L48 was not covered by tests
I1 = jnp.trace(F @ F.T)
return I1

def I1_bar(self, F: Tensor) -> Scalar:
def I1_bar(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first distortional invariant
- **F**: the deformation gradient
- **grad_u**: the displacement gradient
$$
\bar{I}_1 = J^{-2/3}tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)
I1 = jnp.trace(F @ F.T)
J = self.jacobian(F)
J = self.jacobian(grad_u)
return jnp.power(J, -2. / 3.) * I1

def I2(self, F: Tensor) -> Scalar:
def I2(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)

Check warning on line 68 in pancax/constitutive_models/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/base.py#L68

Added line #L68 was not covered by tests
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
return I2

def I2_bar(self, F: Tensor) -> Scalar:
def I2_bar(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
J = self.jacobian(F)
J = self.jacobian(grad_u)
return jnp.power(J, -4. / 3.) * I2

def jacobian(self, F: Tensor) -> Scalar:
def I3(self, grad_u: Tensor) -> Scalar:
J = self.jacobian(grad_u)
return J * J

Check warning on line 86 in pancax/constitutive_models/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/base.py#L85-L86

Added lines #L85 - L86 were not covered by tests

def jacobian(self, grad_u: Tensor) -> Scalar:
r"""
This simply calculate the jacobian but with guard rails
to return nonsensical numbers if a non-positive jacobian
is encountered during training.
- **F**: the deformation gradient
- **grad_u**: the displacement gradient
$$
J = det(\mathbf{F})
$$
"""
F = self.deformation_gradient(grad_u)
J = jnp.linalg.det(F)
J = jax.lax.cond(
J <= 0.0,
Expand All @@ -93,8 +107,8 @@ def jacobian(self, F: Tensor) -> Scalar:
)
return J

def pk1_stress(self, F: Tensor) -> Tensor:
return jax.grad(self.energy, argnums=0)(F)
def pk1_stress(self, grad_u: Tensor) -> Tensor:
return jax.grad(self.energy, argnums=0)(grad_u)

def properties(self):
return self.__dataclass_fields__
6 changes: 3 additions & 3 deletions pancax/constitutive_models/blatz_ko.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
class BlatzKo(BaseConstitutiveModel):
shear_modulus: Property

def energy(self, F):
def energy(self, grad_u):
# unpack properties
G = self.shear_modulus

# kinematics
I2 = self.I2(F)
I3 = jnp.linalg.det(F.T @ F)
I2 = self.I2(grad_u)
I3 = self.I3(grad_u)

Check warning on line 15 in pancax/constitutive_models/blatz_ko.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/blatz_ko.py#L14-L15

Added lines #L14 - L15 were not covered by tests

# constitutive
W = (G / 2.) * (I2 / I3 + 2 * jnp.sqrt(I3) - 5.)
Expand Down
7 changes: 3 additions & 4 deletions pancax/constitutive_models/gent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ class Gent(BaseConstitutiveModel):
shear_modulus: Property
Jm_parameter: Property

def energy(self, F):
def energy(self, grad_u):
# unpack properties
K, G, Jm = self.bulk_modulus, self.shear_modulus, self.Jm_parameter

# kinematics
C = F.T @ F
J = self.jacobian(F)
I_1_bar = jnp.trace(jnp.power(J, -2. / 3.) * C)
J = self.jacobian(grad_u)
I_1_bar = self.I1_bar(grad_u)

# guard rail
check_value = I_1_bar > Jm + 3.0 - 0.001
Expand Down
8 changes: 3 additions & 5 deletions pancax/constitutive_models/neohookean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ class NeoHookean(BaseConstitutiveModel):
bulk_modulus: Property
shear_modulus: Property

def energy(self, F: Tensor) -> Scalar:
def energy(self, grad_u: Tensor) -> Scalar:
K, G = self.bulk_modulus, self.shear_modulus

# kinematics
# F = jnp.eye(3) + grad_u
C = F.T @ F
J = self.jacobian(F)
I_1_bar = jnp.trace(1. / jnp.square(jnp.cbrt(J)) * C)
J = self.jacobian(grad_u)
I_1_bar = self.I1_bar(grad_u)

# constitutive
W_vol = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J))
Expand Down
13 changes: 8 additions & 5 deletions pancax/constitutive_models/swanson.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,32 @@ class Swanson(BaseConstitutiveModel):
# hack because Swanson is a stupid model
cutoff_strain: float = eqx.field(static=True)

def energy(self, F):
def energy(self, grad_u):
K = self.bulk_modulus
A1, P1 = self.A1, self.P1
B1, Q1 = self.B1, self.Q1
C1, R1 = self.C1, self.R1
tau_cutoff = (1. / 3.) * (3. + self.cutoff_strain**2) - 1.

# kinematics
J = self.jacobian(F)
C = F.T @ F
C_bar = jnp.power(J, -2. / 3.) * C
I_1_bar = jnp.trace(C_bar)
J = self.jacobian(grad_u)
I_1_bar = self.I1_bar(grad_u)
I_2_bar = self.I2_bar(grad_u)
tau_1 = (1. / 3.) * I_1_bar - 1.
tau_2 = (1. / 3.) * I_2_bar - 1.
tau_tilde_1 = tau_1 + tau_cutoff
tau_tilde_2 = tau_2 + tau_cutoff

# constitutive
W_vol = K * (J * jnp.log(J) - J + 1.)
W_dev_tau = 3. / 2. * (
A1 / (P1 + 1.) * (tau_tilde_1**(P1 + 1.)) +
B1 / (Q1 + 1.) * (tau_tilde_2**(Q1 + 1.)) +
C1 / (R1 + 1.) * (tau_tilde_1**(R1 + 1.))
)
W_dev_cutoff = 3. / 2. * (
A1 / (P1 + 1.) * (tau_cutoff**(P1 + 1.)) +
B1 / (Q1 + 1.) * (tau_cutoff**(Q1 + 1.)) +
C1 / (R1 + 1.) * (tau_cutoff**(R1 + 1.))
)
W_dev = W_dev_tau - W_dev_cutoff
Expand Down
6 changes: 1 addition & 5 deletions pancax/physics_kernels/solid_mechanics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,5 @@ def __init__(self, constitutive_model, formulation) -> None:
self.formulation = formulation

def energy(self, params, x, t, u, grad_u, *args):
# _, model = params
grad_u = self.formulation.modify_field_gradient(grad_u)
F = grad_u + jnp.eye(3)
return self.constitutive_model.energy(F)
# return self.constitutive_model.energy(grad_u)
# return model.energy(grad_u)
return self.constitutive_model.energy(grad_u)

Check warning on line 72 in pancax/physics_kernels/solid_mechanics.py

View check run for this annotation

Codecov / codecov/patch

pancax/physics_kernels/solid_mechanics.py#L72

Added line #L72 was not covered by tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors = [
]
dependencies = [
'matplotlib',
'meshio',
'netCDF4',
'pandas',
'scipy',
Expand Down
13 changes: 4 additions & 9 deletions test/constitutive_models/test_base_constitutive_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# from pancax import NeoHookean, NeoHookeanFixedBulkModulus
from pancax import NeoHookean
# import jax
import jax.numpy as jnp


Expand All @@ -15,7 +13,8 @@ def test_jacobian():
[0., 2., 0.],
[0., 0., 1.]
])
J = model.jacobian(F)
grad_u = F - jnp.eye(3)
J = model.jacobian(grad_u)
assert jnp.array_equal(J, jnp.linalg.det(F))

# TODO add better test.
Expand All @@ -31,10 +30,6 @@ def test_jacobian_bad_value():
[0., 2., 0.],
[0., 0., -1.]
])
J = model.jacobian(F)
grad_u = F - jnp.eye(3)
J = model.jacobian(grad_u)
assert jnp.array_equal(J, 1.e3)


# def test_bulk_modulus_init():
# # model = NeoHookeanFixedBulkModulus(bulk_modulus=10.0)
# assert model.bulk_modulus == 10.0
19 changes: 10 additions & 9 deletions test/constitutive_models/test_gent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# from pancax import FixedProperties, Gent, GentFixedBulkModulus
from pancax import BoundedProperty, Gent
from .utils import *
import jax
Expand Down Expand Up @@ -33,10 +32,11 @@ def gent_2():
def simple_shear_test(model):
gammas = jnp.linspace(0.0, 1., 100)
Fs = jax.vmap(simple_shear)(gammas)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, gamma, I1_bar, J) in zip(psis, sigmas, gammas, I1_bars, Js):
psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \
Expand All @@ -61,10 +61,11 @@ def simple_shear_test(model):
def uniaxial_strain_test(model):
lambdas = jnp.linspace(1., 2., 100)
Fs = jax.vmap(uniaxial_strain)(lambdas)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, lambda_, I1_bar, J) in zip(psis, sigmas, lambdas, I1_bars, Js):
psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \
Expand Down
19 changes: 10 additions & 9 deletions test/constitutive_models/test_neohookean.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def neohookean_2():
def simple_shear_test(model):
gammas = jnp.linspace(0.0, 1., 100)
Fs = jax.vmap(simple_shear)(gammas)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)#, props())
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)#, props())
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, gamma, I1_bar, J) in zip(psis, sigmas, gammas, I1_bars, Js):
psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \
Expand All @@ -57,11 +58,11 @@ def simple_shear_test(model):
def uniaxial_strain_test(model):
lambdas = jnp.linspace(1., 4., 100)
Fs = jax.vmap(uniaxial_strain)(lambdas)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)#, props())
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)#, props())
# K, G = model.unpack_properties(props())
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, lambda_, I1_bar, J) in zip(psis, sigmas, lambdas, I1_bars, Js):
psi_an = 0.5 * K * (0.5 * (J**2 - 1) - jnp.log(J)) + \
Expand Down
18 changes: 10 additions & 8 deletions test/constitutive_models/test_swanson.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def swanson_2():
def simple_shear_test(model):
gammas = jnp.linspace(0.05, 1., 100)
Fs = jax.vmap(simple_shear)(gammas)
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
B_bars = jax.vmap(lambda F: jnp.power(jnp.linalg.det(F), -2. / 3.) * F @ F.T)(Fs)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, I1_bar, J, B_bar) in zip(psis, sigmas, I1_bars, Js, B_bars):
psi_an = K * (J * jnp.log(J) - J + 1.) + \
Expand All @@ -70,11 +71,12 @@ def simple_shear_test(model):
def uniaxial_strain_test(model):
lambdas = jnp.linspace(1.2, 4., 100)
Fs = jax.vmap(uniaxial_strain)(lambdas)
grad_us = jax.vmap(lambda F: F - jnp.eye(3))(Fs)
B_bars = jax.vmap(lambda F: jnp.power(jnp.linalg.det(F), -2. / 3.) * F @ F.T)(Fs)
Js = jax.vmap(model.jacobian)(Fs)
I1_bars = jax.vmap(model.I1_bar)(Fs)
psis = jax.vmap(model.energy, in_axes=(0,))(Fs)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(Fs)
Js = jax.vmap(model.jacobian)(grad_us)
I1_bars = jax.vmap(model.I1_bar)(grad_us)
psis = jax.vmap(model.energy, in_axes=(0,))(grad_us)
sigmas = jax.vmap(model.cauchy_stress, in_axes=(0,))(grad_us)

for (psi, sigma, I1_bar, J, B_bar) in zip(psis, sigmas, I1_bars, Js, B_bars):
psi_an = K * (J * jnp.log(J) - J + 1.) + \
Expand Down
Loading

0 comments on commit 7d1f59a

Please sign in to comment.