Skip to content

Commit

Permalink
feat(starknet_os): stateless compression util
Browse files Browse the repository at this point in the history
  • Loading branch information
yoavGrs committed Mar 4, 2025
1 parent a57860d commit 13fbe1d
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cmp::min;
use std::collections::HashSet;

use assert_matches::assert_matches;
use num_bigint::BigUint;
Expand All @@ -8,6 +9,8 @@ use rstest::rstest;
use starknet_types_core::felt::Felt;

use super::utils::{
compress,
felt_from_bits_le,
get_bucket_offsets,
get_n_elms_per_felt,
pack_usize_in_felts,
Expand All @@ -20,10 +23,14 @@ use super::utils::{
BucketElementTrait,
Buckets,
CompressionSet,
COMPRESSION_VERSION,
HEADER_ELM_BOUND,
N_UNIQUE_BUCKETS,
TOTAL_N_BUCKETS,
};
use crate::hints::error::OsHintError;

// Number of header elements: version + data length + one length per bucket
// (the unique-value buckets plus the repeating-values bucket).
const HEADER_LEN: usize = 1 + 1 + TOTAL_N_BUCKETS;
// Utils

pub fn unpack_felts<const LENGTH: usize>(
Expand Down Expand Up @@ -61,6 +68,81 @@ pub fn unpack_felts_to_usize(compressed: &[Felt], n_elms: usize, elm_bound: u32)
result
}

/// Decompresses the given compressed data.
///
/// Inverse of `compress`; reads, in order: the packed header, the unique values
/// (one chunk per bit-length bucket), the packed repeating-value pointers, and
/// the packed bucket index per element.
pub fn decompress(compressed: &mut impl Iterator<Item = Felt>) -> Vec<Felt> {
    // Unpacks `n_elms` values of `LENGTH` bits each from the packed felt stream.
    fn unpack_chunk<const LENGTH: usize>(
        compressed: &mut impl Iterator<Item = Felt>,
        n_elms: usize,
    ) -> Vec<Felt> {
        let n_elms_per_felt = BitLength::min_bit_length(LENGTH).unwrap().n_elems_in_felt();
        let n_packed_felts = n_elms.div_ceil(n_elms_per_felt);
        let compressed_chunk: Vec<_> = compressed.take(n_packed_felts).collect();
        unpack_felts(&compressed_chunk, n_elms)
            .into_iter()
            .map(|bits: BitsArray<LENGTH>| felt_from_bits_le(&bits.0).unwrap())
            .collect()
    }

    // Unpacks `n_elms` usize values, each bounded by `elm_bound`, from the stream.
    fn unpack_chunk_to_usize(
        compressed: &mut impl Iterator<Item = Felt>,
        n_elms: usize,
        elm_bound: u32,
    ) -> Vec<usize> {
        let n_packed_felts = n_elms.div_ceil(get_n_elms_per_felt(elm_bound));
        let compressed_chunk: Vec<_> = compressed.take(n_packed_felts).collect();
        unpack_felts_to_usize(&compressed_chunk, n_elms, elm_bound)
    }

    // Header layout: [version, data_len, unique bucket lengths..., n_repeating_values].
    let header = unpack_chunk_to_usize(compressed, HEADER_LEN, HEADER_ELM_BOUND);
    assert!(header[0] == usize::from(COMPRESSION_VERSION), "Unsupported compression version.");

    let data_len = header[1];
    let unique_value_bucket_lengths = &header[2..2 + N_UNIQUE_BUCKETS];
    let n_repeating_values = header[2 + N_UNIQUE_BUCKETS];

    // Unique values, one chunk per bucket, from the widest to the narrowest bit length.
    let mut unique_values = Vec::new();
    unique_values.extend(compressed.take(unique_value_bucket_lengths[0])); // 252 bucket.
    unique_values.extend(unpack_chunk::<125>(compressed, unique_value_bucket_lengths[1]));
    unique_values.extend(unpack_chunk::<83>(compressed, unique_value_bucket_lengths[2]));
    unique_values.extend(unpack_chunk::<62>(compressed, unique_value_bucket_lengths[3]));
    unique_values.extend(unpack_chunk::<31>(compressed, unique_value_bucket_lengths[4]));
    unique_values.extend(unpack_chunk::<15>(compressed, unique_value_bucket_lengths[5]));

    // Repeating values are encoded as pointers into the unique values.
    let repeating_value_pointers = unpack_chunk_to_usize(
        compressed,
        n_repeating_values,
        unique_values.len().try_into().expect("Too many unique values"),
    );
    let repeating_values: Vec<_> =
        repeating_value_pointers.iter().map(|&ptr| unique_values[ptr]).collect();

    let mut all_values = unique_values;
    all_values.extend(repeating_values);

    // For each original element: the index of the bucket it was assigned to.
    let bucket_index_per_elm: Vec<usize> =
        unpack_chunk_to_usize(compressed, data_len, TOTAL_N_BUCKETS.try_into().unwrap());

    let all_bucket_lengths: Vec<usize> =
        unique_value_bucket_lengths.iter().copied().chain([n_repeating_values]).collect();

    // Per-bucket cursor into `all_values`; starts at each bucket's offset and
    // advances as the bucket's values are consumed.
    let mut bucket_offsets = get_bucket_offsets(&all_bucket_lengths);

    // Reconstruct the original sequence by reading the next value of each
    // element's bucket.
    bucket_index_per_elm
        .into_iter()
        .map(|bucket_index| {
            let offset = &mut bucket_offsets[bucket_index];
            let value = all_values[*offset];
            *offset += 1;
            value
        })
        .collect()
}

// Tests

#[rstest]
Expand Down Expand Up @@ -179,3 +261,82 @@ fn test_update_with_unique_values(
assert_eq!(expected_n_repeating_values, compression_set.n_repeating_values());
assert_eq!(expected_repeating_value_pointers, compression_set.get_repeating_value_pointers());
}

// These values are calculated by importing the module and running the compression method
// ```py
// # import compress from compression
// def main() -> int:
// print(compress([2,3,1]))
// return 0
// ```
#[rstest]
#[case::single_value_1(vec![1u32], vec!["0x100000000000000000000000000000100000", "0x1", "0x5"])]
#[case::single_value_2(vec![2u32], vec!["0x100000000000000000000000000000100000", "0x2", "0x5"])]
#[case::single_value_3(vec![10u32], vec!["0x100000000000000000000000000000100000", "0xA", "0x5"])]
#[case::two_values(vec![1u32, 2], vec!["0x200000000000000000000000000000200000", "0x10001", "0x28"])]
#[case::three_values(vec![2u32, 3, 1], vec!["0x300000000000000000000000000000300000", "0x40018002", "0x11d"])]
#[case::four_values(vec![1u32, 2, 3, 4], vec!["0x400000000000000000000000000000400000", "0x8000c0010001", "0x7d0"])]
#[case::extracted_kzg_example(vec![1u32, 1, 6, 1991, 66, 0], vec!["0x10000500000000000000000000000000000600000", "0x841f1c0030001", "0x0", "0x17eff"])]

fn test_compress_decompress(#[case] input: Vec<u32>, #[case] expected: Vec<&str>) {
let data: Vec<_> = input.into_iter().map(Felt::from).collect();
let compressed = compress(&data);
let expected: Vec<_> = expected.iter().map(|s| Felt::from_hex_unchecked(s)).collect();
assert_eq!(compressed, expected);

let decompressed = decompress(&mut compressed.into_iter());
assert_eq!(decompressed, data);
}

#[rstest]
#[case::no_values(
    vec![],
    0, // No buckets.
    None,
)]
#[case::single_value_1(
    vec![Felt::from(7777777)],
    1, // A single bucket with one value.
    Some(300), // 1 header, 1 value, 1 pointer
)]
#[case::large_duplicates(
    vec![Felt::from(BigUint::from(2_u8).pow(250)); 100],
    1, // Should remove duplicated values.
    Some(5),
)]
#[case::small_values(
    (0..0x8000).map(Felt::from).collect(),
    2048, // = 2**15/(251/15), as all elements are packed in the 15-bits bucket.
    Some(7),
)]
#[case::mixed_buckets(
    (0..252).map(|i| Felt::from(BigUint::from(2_u8).pow(i))).collect(),
    1 + 2 + 8 + 7 + 21 + 127, // All buckets are involved here.
    Some(67), // More than half of the values are in the biggest (252-bit) bucket.
)]
fn test_compression_length(
    #[case] data: Vec<Felt>,
    #[case] expected_unique_values_packed_length: usize,
    #[case] expected_compression_percents: Option<usize>,
) {
    let compressed = compress(&data);

    // Derive the expected packed lengths of the pointer and bucket-index
    // sections from the number of unique/repeated values in the input.
    let unique_count = data.iter().collect::<HashSet<_>>().len();
    let repeated_count = data.len() - unique_count;
    let pointers_packed_len =
        repeated_count.div_ceil(get_n_elms_per_felt(u32::try_from(unique_count).unwrap()));
    let indices_packed_len =
        data.len().div_ceil(get_n_elms_per_felt(u32::try_from(TOTAL_N_BUCKETS).unwrap()));

    // One felt for the header, then the three packed sections.
    let expected_total_len = 1
        + expected_unique_values_packed_length
        + pointers_packed_len
        + indices_packed_len;
    assert_eq!(compressed.len(), expected_total_len);

    if let Some(expected_percents) = expected_compression_percents {
        assert_eq!(100 * compressed.len() / data.len(), expected_percents);
    }

    // The compression must be lossless.
    assert_eq!(data, decompress(&mut compressed.into_iter()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ use std::hash::Hash;

use indexmap::IndexMap;
use num_bigint::BigUint;
use num_traits::Zero;
use num_traits::{ToPrimitive, Zero};
use starknet_types_core::felt::Felt;
use strum::EnumCount;
use strum_macros::Display;

use crate::hints::error::OsHintError;

/// Version tag written as the first header element; decompression asserts it matches.
pub(crate) const COMPRESSION_VERSION: u8 = 0;
/// Bit width allotted to a single header element when packed into felts.
pub(crate) const HEADER_ELM_N_BITS: usize = 20;
/// Exclusive upper bound on the value of a single header element.
pub(crate) const HEADER_ELM_BOUND: u32 = 1 << HEADER_ELM_N_BITS;

/// Number of buckets holding unique values — one per `BitLength` variant.
pub(crate) const N_UNIQUE_BUCKETS: usize = BitLength::COUNT;
/// Number of buckets, including the repeating values bucket.
pub(crate) const TOTAL_N_BUCKETS: usize = N_UNIQUE_BUCKETS + 1;
Expand Down Expand Up @@ -400,6 +404,42 @@ impl CompressionSet {
}
}

/// Compresses the data provided to output a Vec of compressed Felts.
///
/// Output layout: packed header | packed unique values | packed repeating-value
/// pointers | packed bucket index per element.
pub(crate) fn compress(data: &[Felt]) -> Vec<Felt> {
    // The data length is stored in a single header element, so it must fit the bound.
    assert!(
        data.len() < HEADER_ELM_BOUND.to_usize().expect("usize overflow"),
        "Data is too long: {} >= {HEADER_ELM_BOUND}.",
        data.len()
    );

    let compression_set = CompressionSet::new(data);
    let bucket_lengths = compression_set.get_unique_value_bucket_lengths();
    let n_unique_values: usize = bucket_lengths.iter().sum();

    // Header: version, original data length, per-bucket unique-value counts,
    // and the number of repeating values.
    let mut header: Vec<usize> = vec![COMPRESSION_VERSION.into(), data.len()];
    header.extend(bucket_lengths);
    header.push(compression_set.n_repeating_values());

    let packed_header = pack_usize_in_felts(&header, HEADER_ELM_BOUND);
    let packed_pointers = pack_usize_in_felts(
        &compression_set.get_repeating_value_pointers(),
        u32::try_from(n_unique_values).expect("Too many unique values"),
    );
    let packed_bucket_indices = pack_usize_in_felts(
        &compression_set.bucket_index_per_elm,
        u32::try_from(TOTAL_N_BUCKETS).expect("Too many buckets"),
    );
    let packed_unique_values = compression_set.pack_unique_values();

    // Concatenate the sections in the order expected by decompression.
    let mut result = packed_header;
    result.extend(packed_unique_values);
    result.extend(packed_pointers);
    result.extend(packed_bucket_indices);
    result
}

/// Calculates the number of elements with the same bit length as the element bound, that can fit
/// into a single felt value.
pub fn get_n_elms_per_felt(elm_bound: u32) -> usize {
Expand Down

0 comments on commit 13fbe1d

Please sign in to comment.