Skip to content

Commit

Permalink
Dynamic memory management preset + updated wgpu buffer memory managem…
Browse files Browse the repository at this point in the history
…ent (#1962)


---------

Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
  • Loading branch information
nathanielsimard and mepatrick73 authored Jul 4, 2024
1 parent 5236e12 commit 51aea94
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 130 deletions.
11 changes: 9 additions & 2 deletions crates/burn-compute/benches/dynamic.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
use std::collections::LinkedList;

use burn_compute::{
memory_management::{dynamic::DynamicMemoryManagement, MemoryManagement},
memory_management::{
dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
MemoryManagement,
},
storage::BytesStorage,
};

const MB: usize = 1024 * 1024;

fn main() {
let start = std::time::Instant::now();
let storage = BytesStorage::default();
let mut mm = DynamicMemoryManagement::new(storage);
let mut mm = DynamicMemoryManagement::new(
storage,
DynamicMemoryManagementOptions::preset(2048 * MB, 32),
);
let mut handles = LinkedList::new();
for _ in 0..100 * 2048 {
if handles.len() >= 4000 {
Expand Down
171 changes: 117 additions & 54 deletions crates/burn-compute/src/memory_management/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,107 @@ use super::memory_pool::{
SmallMemoryPool,
};
use crate::storage::ComputeStorage;
use alloc::vec::Vec;

use super::MemoryManagement;

/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct DynamicMemoryManagement<Storage> {
min_chunk_alignment_offset: usize,
small_memory_pool: SmallMemoryPool,
small_medium_memory_pool: MemoryPool,
medium_memory_pool: MemoryPool,
main_memory_pool: MemoryPool,
pools: Vec<MemoryPool>,
options: Vec<MemoryPoolOptions>,
storage: Storage,
}

/// Options to initialize a [dynamic memory management](DynamicMemoryManagement).
#[derive(new, Debug)]
pub struct DynamicMemoryManagementOptions {
pools: Vec<MemoryPoolOptions>,
min_chunk_alignment_offset: usize,
}

/// Options to create a memory pool.
#[derive(Debug)]
pub struct MemoryPoolOptions {
/// The amount of bytes used for each chunk in the memory pool.
pub chunk_size: usize,
/// The number of chunks allocated directly at creation.
///
/// Useful when you know in advance how much memory you'll need.
pub chunk_num_prealloc: usize,
/// The max size in bytes a slice can take in the pool.
pub slice_max_size: usize,
}

impl DynamicMemoryManagementOptions {
/// Creates the options from device limits.
pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self {
// Rounding down to a factor of 8.
let max_chunk_size = (max_chunk_size / 8) * 8;

const MB: usize = 1024 * 1024;

let mut pools = Vec::new();

pools.push(MemoryPoolOptions {
chunk_size: max_chunk_size,
chunk_num_prealloc: 0,
slice_max_size: max_chunk_size,
});

let mut current = max_chunk_size;

while current >= 32 * MB {
current /= 4;

pools.push(MemoryPoolOptions {
chunk_size: current,
chunk_num_prealloc: 0,
// Creating max slices lower than the chunk size reduces fragmentation.
slice_max_size: current / 2usize.pow(pools.len() as u32),
});
}

Self {
pools,
min_chunk_alignment_offset,
}
}
}

impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
/// Creates a new instance using the given storage, merging_strategy strategy and slice strategy.
pub fn new(storage: Storage) -> Self {
let main_memory_pool = MemoryPool::new(
MemoryExtensionStrategy::new_period_tick(10),
RoundingStrategy::FixedAmount(1024 * 1024 * 1024),
);
let medium_memory_pool = MemoryPool::new(
MemoryExtensionStrategy::Never,
RoundingStrategy::FixedAmount(1024 * 1024 * 200),
);
let small_medium_memory_pool = MemoryPool::new(
MemoryExtensionStrategy::Never,
RoundingStrategy::FixedAmount(1024 * 1024 * 2),
);
let small_memory_pool = SmallMemoryPool::new();
pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self {
options
.pools
.sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size));

let min_chunk_alignment_offset = options.min_chunk_alignment_offset;

let pools = options
.pools
.iter()
.map(|option| {
let mut pool = MemoryPool::new(
MemoryExtensionStrategy::Never,
RoundingStrategy::FixedAmount(option.chunk_size),
min_chunk_alignment_offset,
);

for _ in 0..option.chunk_num_prealloc {
pool.alloc(&mut storage, option.chunk_size, || {});
}

pool
})
.collect();

Self {
small_memory_pool,
small_medium_memory_pool,
main_memory_pool,
medium_memory_pool,
min_chunk_alignment_offset,
small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset),
pools,
options: options.pools,
storage,
}
}
Expand All @@ -62,50 +130,45 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
return handle;
}

if let Some(handle) = self
.small_medium_memory_pool
.get(&mut self.storage, &binding)
{
return handle;
for pool in &mut self.pools {
if let Some(handle) = pool.get(&mut self.storage, &binding) {
return handle;
}
}

if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
return handle;
panic!("No handle found in memory pools");
}

fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
if size <= self.min_chunk_alignment_offset {
return self
.small_memory_pool
.reserve(&mut self.storage, size, sync);
}

if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
return handle;
for (index, option) in self.options.iter().enumerate() {
if size <= option.slice_max_size {
let pool = &mut self.pools[index];
return pool.reserve(&mut self.storage, size, sync);
}
}

panic!("No handle found in the small and main memory pool");
panic!("No memory pool big enough to reserve {size} bytes.");
}

fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
if size <= 32 {
self.small_memory_pool
.reserve(&mut self.storage, size, sync)
} else if size <= 2 * 1024 * 1024 {
self.small_medium_memory_pool
.reserve(&mut self.storage, size, sync)
} else if size < 200 * 1024 * 1024 {
self.medium_memory_pool
.reserve(&mut self.storage, size, sync)
} else {
self.main_memory_pool.reserve(&mut self.storage, size, sync)
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
if size <= self.min_chunk_alignment_offset {
return self.small_memory_pool.alloc(&mut self.storage, size, sync);
}
}

fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
if size <= 32 {
self.small_memory_pool.alloc(&mut self.storage, size, sync)
} else if size <= 2 * 1024 * 1024 {
self.small_medium_memory_pool
.alloc(&mut self.storage, size, sync)
} else if size <= 200 * 1024 * 1024 {
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
} else {
self.main_memory_pool.alloc(&mut self.storage, size, sync)
for (index, option) in self.options.iter().enumerate() {
if size <= option.slice_max_size {
let pool = &mut self.pools[index];
return pool.alloc(&mut self.storage, size, sync);
}
}

panic!("No memory pool big enough to alloc {size} bytes.");
}

fn dealloc(&mut self, _binding: Self::Binding) {
Expand Down
Loading

0 comments on commit 51aea94

Please sign in to comment.