Skip to content

Commit

Permalink
organizing constitutive models based on types and moving all interfac…
Browse files Browse the repository at this point in the history
…es from F to grad u so it will be consistent across all physics.
  • Loading branch information
cmhamel committed Mar 2, 2025
1 parent 7d1f59a commit d500fb9
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 120 deletions.
10 changes: 5 additions & 5 deletions pancax/constitutive_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .base import BaseConstitutiveModel
from .base import ConstitutiveModel
from .properties import BoundedProperty, FixedProperty, Property

# models
from .blatz_ko import BlatzKo
from .gent import Gent
from .neohookean import NeoHookean
from .swanson import Swanson
from .mechanics.hyperelasticity.blatz_ko import BlatzKo
from .mechanics.hyperelasticity.gent import Gent
from .mechanics.hyperelasticity.neohookean import NeoHookean
from .mechanics.hyperelasticity.swanson import Swanson
105 changes: 2 additions & 103 deletions pancax/constitutive_models/base.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,13 @@
from abc import abstractmethod
from jaxtyping import Array, Float
from typing import Tuple
import equinox as eqx
import jax
import jax.numpy as jnp


Scalar = float
State = Float[Array, "ns"]
Tensor = Float[Array, "3 3"]
Vector = Float[Array, "3"]


class BaseConstitutiveModel(eqx.Module):
def cauchy_stress(self, grad_u: Tensor) -> Tensor:
F = grad_u + jnp.eye(3)
J = self.jacobian(grad_u)
P = self.pk1_stress(grad_u)
return (1. / J) * P @ F.T

def deformation_gradient(self, grad_u: Tensor) -> Tensor:
F = grad_u + jnp.eye(3)
return F

@abstractmethod
def energy(self, grad_u: Tensor) -> Scalar:
"""
This method returns the algorithmic strain energy density.
"""
pass

def invariants(self, grad_u: Tensor) -> Tuple[Scalar, Scalar, Scalar]:
I1 = self.I1(grad_u)
I2 = self.I2(grad_u)
I3 = self.I3(grad_u)
return jnp.array([I1, I2, I3])

def I1(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first invariant
- **grad_u**: the displacement gradient
$$
I_1 = tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)
I1 = jnp.trace(F @ F.T)
return I1

def I1_bar(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first distortional invariant
- **grad_u**: the displacement gradient
$$
\bar{I}_1 = J^{-2/3}tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)
I1 = jnp.trace(F @ F.T)
J = self.jacobian(grad_u)
return jnp.power(J, -2. / 3.) * I1

def I2(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
return I2

def I2_bar(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
J = self.jacobian(grad_u)
return jnp.power(J, -4. / 3.) * I2

def I3(self, grad_u: Tensor) -> Scalar:
J = self.jacobian(grad_u)
return J * J

def jacobian(self, grad_u: Tensor) -> Scalar:
r"""
This simply calculate the jacobian but with guard rails
to return nonsensical numbers if a non-positive jacobian
is encountered during training.
- **grad_u**: the displacement gradient
$$
J = det(\mathbf{F})
$$
"""
F = self.deformation_gradient(grad_u)
J = jnp.linalg.det(F)
J = jax.lax.cond(
J <= 0.0,
lambda _: 1.0e3,
lambda x: x,
J
)
return J

def pk1_stress(self, grad_u: Tensor) -> Tensor:
return jax.grad(self.energy, argnums=0)(grad_u)

class ConstitutiveModel(eqx.Module):
def properties(self):
return self.__dataclass_fields__
Empty file.
109 changes: 109 additions & 0 deletions pancax/constitutive_models/mechanics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from abc import abstractmethod
from ..base import ConstitutiveModel, Scalar, Tensor
from typing import Tuple
import jax
import jax.numpy as jnp


class MechanicsModel(ConstitutiveModel):
def cauchy_stress(self, grad_u: Tensor, *args) -> Tensor:
F = self.deformation_gradient(grad_u)
J = self.jacobian(grad_u)
P = self.pk1_stress(grad_u, *args)
return (1. / J) * P @ F.T

def deformation_gradient(self, grad_u: Tensor) -> Tensor:
F = grad_u + jnp.eye(3)
return F

@abstractmethod
def energy(self, grad_u: Tensor, *args) -> Scalar:
"""
This method returns the algorithmic strain energy density.
"""
pass

def invariants(self, grad_u: Tensor) -> Tuple[Scalar, Scalar, Scalar]:
I1 = self.I1(grad_u)
I2 = self.I2(grad_u)
I3 = self.I3(grad_u)
return jnp.array([I1, I2, I3])

Check warning on line 30 in pancax/constitutive_models/mechanics/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/mechanics/base.py#L27-L30

Added lines #L27 - L30 were not covered by tests

def I1(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first invariant
- **grad_u**: the displacement gradient
$$
I_1 = tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)
I1 = jnp.trace(F @ F.T)
return I1

Check warning on line 44 in pancax/constitutive_models/mechanics/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/mechanics/base.py#L42-L44

Added lines #L42 - L44 were not covered by tests

def I1_bar(self, grad_u: Tensor) -> Scalar:
r"""
Calculates the first distortional invariant
- **grad_u**: the displacement gradient
$$
\bar{I}_1 = J^{-2/3}tr\left(\mathbf{F}^T\mathbf{F}\right)
$$
"""
F = self.deformation_gradient(grad_u)
I1 = jnp.trace(F @ F.T)
J = self.jacobian(grad_u)
return jnp.power(J, -2. / 3.) * I1

def I2(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
return I2

Check warning on line 67 in pancax/constitutive_models/mechanics/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/mechanics/base.py#L62-L67

Added lines #L62 - L67 were not covered by tests

def I2_bar(self, grad_u: Tensor) -> Scalar:
F = self.deformation_gradient(grad_u)
C = F.T @ F
C2 = C @ C
I1 = jnp.trace(C)
I2 = 0.5 * (I1**2 - jnp.trace(C2))
J = self.jacobian(grad_u)
return jnp.power(J, -4. / 3.) * I2

def I3(self, grad_u: Tensor) -> Scalar:
J = self.jacobian(grad_u)
return J * J

Check warning on line 80 in pancax/constitutive_models/mechanics/base.py

View check run for this annotation

Codecov / codecov/patch

pancax/constitutive_models/mechanics/base.py#L79-L80

Added lines #L79 - L80 were not covered by tests

def jacobian(self, grad_u: Tensor) -> Scalar:
r"""
This simply calculate the jacobian but with guard rails
to return nonsensical numbers if a non-positive jacobian
is encountered during training.
- **grad_u**: the displacement gradient
$$
J = det(\mathbf{F})
$$
"""
F = self.deformation_gradient(grad_u)
J = jnp.linalg.det(F)
J = jax.lax.cond(
J <= 0.0,
lambda _: 1.0e3,
lambda x: x,
J
)
return J

def pk1_stress(self, grad_u: Tensor, *args) -> Tensor:
return jax.grad(self.energy, argnums=0)(grad_u, *args)


class HyperelasticModel(MechanicsModel):
pass
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .base import BaseConstitutiveModel
from .properties import Property
from ..base import HyperelasticModel
from ...properties import Property
import jax.numpy as jnp


class BlatzKo(BaseConstitutiveModel):
class BlatzKo(HyperelasticModel):
shear_modulus: Property

def energy(self, grad_u):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .base import BaseConstitutiveModel
from .properties import Property
from ..base import HyperelasticModel
from ...properties import Property
import jax
import jax.numpy as jnp


class Gent(BaseConstitutiveModel):
class Gent(HyperelasticModel):
r"""
Gent model with the following model form
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .base import BaseConstitutiveModel, Scalar, Tensor
from .properties import Property
from ..base import HyperelasticModel, Scalar, Tensor
from ...properties import Property
import jax.numpy as jnp


class NeoHookean(BaseConstitutiveModel):
class NeoHookean(HyperelasticModel):
r"""
NeoHookean model with the following model form
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .base import BaseConstitutiveModel
from .properties import Property
from ..base import HyperelasticModel
from ...properties import Property
import equinox as eqx
import jax.numpy as jnp


class Swanson(BaseConstitutiveModel):
class Swanson(HyperelasticModel):
r"""
Swanson model truncated to 4 parameters
Expand Down

0 comments on commit d500fb9

Please sign in to comment.