Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed May 2, 2024
1 parent 365bf45 commit fd772ac
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{graph::Node, tensor::NodeRefCount, NodeID};
use crate::{tensor::NodeRefCount, NodeID};
use std::{
collections::{HashMap, HashSet},
mem,
Expand All @@ -19,28 +19,6 @@ enum NodeMemoryStatus {
Unknown,
}

// Wrapper over hash set for fast popping of any node
#[derive(new, Default)]
struct PopNodeSet {
hash_set: HashSet<NodeID>,
}

impl PopNodeSet {
fn pop(&mut self) -> Option<NodeID> {
self.hash_set
.iter()
.next()
.copied()
.and_then(|node_id| self.hash_set.take(&node_id))
}
fn contains(&self, node_id: &NodeID) -> bool {
self.hash_set.contains(node_id)
}
fn insert(&mut self, node_id: NodeID) {
self.hash_set.insert(node_id);
}
}

impl GraphMemoryManagement {
/// Register a new node with its parent.
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
Expand Down Expand Up @@ -245,3 +223,30 @@ impl GraphMemoryManagement {
}
}
}

/// Wrapper over hash set for fast popping of any node
#[derive(new, Default)]
struct PopNodeSet {
hash_set: HashSet<NodeID>,
}

impl PopNodeSet {
#[inline(always)]
fn pop(&mut self) -> Option<NodeID> {
self.hash_set
.iter()
.next()
.copied()
.and_then(|node_id| self.hash_set.take(&node_id))
}

#[inline(always)]
fn contains(&self, node_id: &NodeID) -> bool {
self.hash_set.contains(node_id)
}

#[inline(always)]
fn insert(&mut self, node_id: NodeID) {
self.hash_set.insert(node_id);
}
}

0 comments on commit fd772ac

Please sign in to comment.