add HuggingFace SafeTensors to MLM Artifact Types Best Practices #71

Merged
merged 4 commits into from
Feb 18, 2025
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased](https://github.com/stac-extensions/mlm/tree/main)

### Added
- Add [`huggingface/safetensors`](https://github.com/huggingface/safetensors)
recommendations for `mlm:artifact_type` and corresponding `mlm:framework` values
(fixes [#68](https://github.com/stac-extensions/mlm/issues/68)).
- Add [`Flax`](https://github.com/google/flax) to the list of `mlm:framework` and
the corresponding `mlm:artifact_type` SafeTensors backend in the JSON schema examples.
- Add [`Paddle`](https://github.com/PaddlePaddle/Paddle) to the list of `mlm:framework`
(fixes [#69](https://github.com/stac-extensions/mlm/issues/69)).

19 changes: 10 additions & 9 deletions README.md
@@ -121,7 +121,7 @@ The fields in the table below can be used in these parts of STAC documents:
| mlm:name <sup>[\[1\]][1]</sup> | string | **REQUIRED** A name for the model. This can include, but must be distinct from, simply naming the model architecture. If there is a publication or other published work related to the model, use the official name of the model. |
| mlm:architecture | [Model Architecture](#model-architecture) string | **REQUIRED** A generic and well established architecture name of the model. |
| mlm:tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the model can be used. If multi-task outputs are provided by distinct model heads, specify all available tasks under the main properties and specify respective tasks in each [Model Output Object](#model-output-object). |
| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). |
| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). Typically, this will align with the applied `mlm:artifact_type` of the [Model Asset](#model-asset). |
| mlm:framework_version | string | The `framework` library version. Some models require a specific version of the machine learning `framework` to run. |
| mlm:memory_size | integer | The in-memory size of the model on the accelerator during inference (bytes). |
| mlm:total_parameters | integer | Total number of model parameters, including trainable and non-trainable parameters. |
@@ -238,6 +238,7 @@ to use common names when applicable. Below are a few notable entries.
- [`rgee`](https://github.com/r-spatial/rgee)
- [`spatialRF`](https://github.com/BlasBenito/spatialRF)
- [`JAX`](https://github.com/jax-ml/jax)
- [`Flax`](https://github.com/google/flax)
- [`MXNet`](https://github.com/apache/mxnet)
- [`Caffe`](https://github.com/BVLC/caffe)
- [`PyMC`](https://github.com/pymc-devs/pymc)
@@ -662,14 +663,14 @@ In order to provide more context, the following roles are also recommended were

### Model Asset

| Field Name | Type | Description |
|-------------------|---------------------------------|--------------------------------------------------------------------------------------------------|
| title | string | Description of the model asset. |
| href | string | URI to the model artifact. |
| type              | string                          | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type)).    |
| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. |
| mlm:artifact_type | [Artifact Type](./best-practices.md#framework-specific-artifact-types) | Specifies the kind of model artifact, any string is allowed. Typically related to a particular ML framework, see [Best Practices - Framework Specific Artifact Types](./best-practices.md#framework-specific-artifact-types) for **RECOMMENDED** values. This field is **REQUIRED** if the `mlm:model` role is specified. |
| mlm:compile_method | [Compile Method](#compile-method) \| null | Describes the method used to compile the ML model either when the model is saved or at model runtime prior to inference. |
| Field Name | Type | Description |
|--------------------|------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| title | string | Description of the model asset. |
| href | string | URI to the model artifact. |
| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). |
| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. |
| mlm:artifact_type | [Artifact Type](./best-practices.md#framework-specific-artifact-types) | Specifies the kind of model artifact, any string is allowed. Typically related to a particular ML framework, see [Best Practices - Framework Specific Artifact Types](./best-practices.md#framework-specific-artifact-types) for **RECOMMENDED** values. This field is **REQUIRED** if the `mlm:model` role is specified. |
| mlm:compile_method | [Compile Method](#compile-method) \| null | Describes the method used to compile the ML model either when the model is saved or at model runtime prior to inference. |

Recommended Asset `roles` include `mlm:weights` or `mlm:checkpoint` for model weights that need to be loaded by a
model definition and `mlm:compiled` for models that can be loaded directly without an intermediate model definition.
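
As a rough illustration of the Model Asset table above, the following Python sketch checks the two rules it states: the `mlm:model` role is required, and `mlm:artifact_type` is required whenever that role is set. The function name `check_model_asset`, the asset `title`, and the `href` are hypothetical and are not part of the MLM extension; actual validation is done by the JSON schema.

```python
def check_model_asset(asset: dict) -> list[str]:
    """Return a list of problems found in a Model Asset (empty if none).

    Hypothetical helper: it only mirrors the two rules from the table above.
    """
    problems = []
    roles = asset.get("roles", [])
    if "mlm:model" not in roles:
        problems.append("roles must include 'mlm:model'")
    elif not asset.get("mlm:artifact_type"):
        problems.append(
            "'mlm:artifact_type' is required when the 'mlm:model' role is set"
        )
    return problems


# Hypothetical asset: title and href are placeholders for illustration only.
asset = {
    "title": "Example model weights",
    "href": "https://example.com/model.safetensors",
    "roles": ["mlm:model", "mlm:weights"],
    "mlm:artifact_type": "safetensors.torch.save_file",
}

print(check_model_asset(asset))  # prints []
```

Returning a list of messages rather than raising on the first problem keeps the helper easy to extend if more asset rules need checking.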
21 changes: 13 additions & 8 deletions best-practices.md
@@ -301,14 +301,15 @@ permitted, as these values are not validated by the schema. Note that the names
framework-specific definitions to help the users understand how the model artifact was created, although these exact
names are not strictly required either.

| Artifact Type | Description |
|--------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `torch.save` | A [serialized python pickle object][pytorch-save] (i.e.: `.pt`) which can represent a model or state_dict. |
| `torch.jit.save` | A [`TorchScript`][pytorch-jit-script] model artifact obtained with one or more of the graph export options Torchscript Tracing and Torchscript Scripting. |
| `torch.export.save` | A model artifact storing an [ExportedProgram][exported-program] obtained by [`torch.export.export`][pytorch-export] (i.e.: `.pt2`). |
| `tf.keras.Model.save` | Saves a [.keras model file][keras-model], a unified zip archive format containing the architecture, weights, optimizer, losses, and metrics. |
| `tf.keras.Model.save_weights` | A [.weights.h5][keras-save-weights] file containing only model weights for use by Tensorflow or Keras. |
| `tf.keras.Model.export` | [TF Saved Model][tf-saved-model] is the [recommended format][tf-keras-recommended] by the Tensorflow team for whole model saving/loading for inference. See the docs for [different save methods][keras-methods] in TF and Keras. |
| Artifact Type | Description |
|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `torch.save` | A [serialized python pickle object][pytorch-save] (i.e.: `.pt`) which can represent a model or state_dict. |
| `torch.jit.save` | A [`TorchScript`][pytorch-jit-script] model artifact obtained with one or more of the graph export options TorchScript Tracing and Scripting. |
| `torch.export.save` | A model artifact storing an [ExportedProgram][exported-program] obtained by [`torch.export.export`][pytorch-export] (i.e.: `.pt2`). |
| `tf.keras.Model.save` | Saves a [.keras model file][keras-model], a unified zip archive format containing the architecture, weights, optimizer, losses, and metrics. |
| `tf.keras.Model.save_weights`      | A [.weights.h5][keras-save-weights] file containing only model weights for use by TensorFlow or Keras.                                                                                                                                                                                                                                        |
| `tf.keras.Model.export`            | [TF Saved Model][tf-saved-model] is the [recommended format][tf-keras-recommended] by the TensorFlow team for whole model saving/loading for inference. See the docs for [different save methods][keras-methods] in TF and Keras.                                                                                                             |
| `safetensors.{framework}.{method}` | Model weights saved as [HuggingFace SafeTensors][hf-st], where `{framework}` matches the [`mlm:framework`][mlm-framework] of a [*supported framework*][hf-st-support] and `{method}` matches the applicable method from SafeTensors. For example, a PyTorch model saved this way would indicate [`safetensors.torch.save_file`][hf-st-torch]. |

[exported-program]: https://pytorch.org/docs/main/export.html#serialization
[pytorch-aot-inductor]: https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
@@ -321,3 +322,7 @@ names are not strictly required either.
[tf-keras-recommended]: https://www.tensorflow.org/guide/saved_model#creating_a_savedmodel_from_keras
[keras-methods]: https://keras.io/2.16/api/models/model_saving_apis/
[keras-model]: https://keras.io/api/models/model_saving_apis/model_saving_and_loading/
[hf-st]: https://github.com/huggingface/safetensors
[hf-st-support]: https://huggingface.co/docs/safetensors/index#featured-projects
[hf-st-torch]: https://huggingface.co/docs/safetensors/api/torch#safetensors.torch.save_file
[mlm-framework]: README.md#item-properties-and-collection-fields
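
The `safetensors.{framework}.{method}` naming pattern from the table above can be sketched as a small Python helper. This is an illustration only: the function name `safetensors_artifact_type` and the mapping from `mlm:framework` values to SafeTensors backend module names are assumptions, limited to the four backends that appear in the JSON schema examples of this change.

```python
# Assumed mapping from mlm:framework values to SafeTensors backend modules.
# Only the backends listed in the schema examples of this change are included.
SAFETENSORS_BACKENDS = {
    "PyTorch": "torch",
    "TensorFlow": "tensorflow",
    "Flax": "flax",
    "Paddle": "paddle",
}


def safetensors_artifact_type(framework: str, method: str = "save_file") -> str:
    """Build an mlm:artifact_type of the form safetensors.{framework}.{method}."""
    try:
        backend = SAFETENSORS_BACKENDS[framework]
    except KeyError:
        raise ValueError(
            f"no SafeTensors backend known for framework {framework!r}"
        ) from None
    return f"safetensors.{backend}.{method}"


print(safetensors_artifact_type("PyTorch"))  # prints safetensors.torch.save_file
```

Because `mlm:artifact_type` accepts any string, a framework missing from this mapping is not invalid; the helper simply refuses to guess a backend name for it.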
7 changes: 6 additions & 1 deletion json-schema/schema.json
@@ -436,6 +436,7 @@
"rgee",
"spatialRF",
"JAX",
"Flax",
"MXNet",
"Caffe",
"PyMC",
@@ -465,7 +466,11 @@
"torch.export.save",
"tf.keras.Model.save",
"tf.keras.Model.save_weights",
"tf.saved_model.export(format='tf_saved_model')"
"tf.keras.Model.export",
"safetensors.torch.save_file",
"safetensors.tensorflow.save_file",
"safetensors.flax.save_file",
"safetensors.paddle.save_file"
]
},
"mlm:compile_method": {