Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve pickle (CandleTensor) conversions to NestedValue #1944

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use serde::{
/// NOTE: This is used to serialize Param structs into NestedValues and not so much for
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
#[derive(Clone)]
pub struct Serializer {
/// The state of the serialization process
state: Option<NestedValue>,
Expand Down
23 changes: 21 additions & 2 deletions crates/burn-import/src/pytorch/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{adapter::PyTorchAdapter, error::Error};

use burn::{
module::ParamId,
record::{ParamSerde, PrecisionSettings},
record::PrecisionSettings,
tensor::{Element, ElementConversion, TensorData},
};
use burn::{
Expand Down Expand Up @@ -141,7 +141,26 @@ where
.map(ElementConversion::elem)
.collect();

ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer)
let TensorData {
bytes,
shape,
dtype,
} = TensorData::new(data, shape);

// Manually serialize the tensor instead of using the `ParamSerde` struct, such as:
// ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer)
// Because serializer copies individual elements of TensorData `value` into a new Vec<u8>,
// which is not necessary and inefficient.
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
tensor_data.insert("bytes".into(), NestedValue::U8s(bytes));
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);

let mut param: HashMap<String, NestedValue> = HashMap::new();
param.insert("id".into(), NestedValue::String(param_id));
param.insert("param".into(), NestedValue::Map(tensor_data));

Ok(NestedValue::Map(param))
}

/// New type struct for Candle tensors because we need to implement the `Serializable` trait for it.
Expand Down
38 changes: 19 additions & 19 deletions crates/burn-tensor/src/tensor/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub enum DataError {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
/// The values of the tensor (as bytes).
value: Vec<u8>,
pub bytes: Vec<u8>,

/// The shape of the tensor.
pub shape: Vec<usize>,
Expand All @@ -42,7 +42,7 @@ impl TensorData {
/// Creates a new tensor data structure.
pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
Self {
value: bytemuck::checked::cast_slice(&value).to_vec(),
bytes: bytemuck::checked::cast_slice(&value).to_vec(),
shape: shape.into(),
dtype: E::dtype(),
}
Expand All @@ -51,7 +51,7 @@ impl TensorData {
/// Returns the immutable slice view of the tensor data.
pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
if E::dtype() == self.dtype {
bytemuck::checked::try_cast_slice(&self.value).map_err(DataError::CastError)
bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
} else {
Err(DataError::TypeMismatch(format!(
"Invalid target element type (expected {:?}, got {:?})",
Expand All @@ -67,7 +67,7 @@ impl TensorData {
/// If the target element type is different from the stored element type.
pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
if E::dtype() == self.dtype {
bytemuck::checked::try_cast_slice_mut(&mut self.value).map_err(DataError::CastError)
bytemuck::checked::try_cast_slice_mut(&mut self.bytes).map_err(DataError::CastError)
} else {
Err(DataError::TypeMismatch(format!(
"Invalid target element type (expected {:?}, got {:?})",
Expand All @@ -85,62 +85,62 @@ impl TensorData {
/// Returns an iterator over the values of the tensor data.
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
if E::dtype() == self.dtype {
Box::new(bytemuck::checked::cast_slice(&self.value).iter().copied())
Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
} else {
match self.dtype {
DType::I8 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i8| e.elem::<E>()),
),
DType::I16 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i16| e.elem::<E>()),
),
DType::I32 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i32| e.elem::<E>()),
),
DType::I64 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &i64| e.elem::<E>()),
),
DType::U8 => Box::new(self.value.iter().map(|e| e.elem::<E>())),
DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
DType::U32 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u32| e.elem::<E>()),
),
DType::U64 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &u64| e.elem::<E>()),
),
DType::BF16 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &bf16| e.elem::<E>()),
),
DType::F16 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f16| e.elem::<E>()),
),
DType::F32 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f32| e.elem::<E>()),
),
DType::F64 => Box::new(
bytemuck::checked::cast_slice(&self.value)
bytemuck::checked::cast_slice(&self.bytes)
.iter()
.map(|e: &f64| e.elem::<E>()),
),
// bool is a byte value equal to either 0 or 1
DType::Bool => Box::new(self.value.iter().map(|e| e.elem::<E>())),
DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
}
}
}
Expand Down Expand Up @@ -220,7 +220,7 @@ impl TensorData {

/// Returns the data as a slice of bytes.
pub fn as_bytes(&self) -> &[u8] {
self.value.as_slice()
self.bytes.as_slice()
}

/// Asserts the data is approximately equal to another data.
Expand Down Expand Up @@ -891,7 +891,7 @@ mod tests {
&mut StdRng::from_entropy(),
);

assert_eq!(num_elements, data.value.len() / 4); // f32 stored as u8s
assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
}

Expand Down
Loading