Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff Memory Management: BFS #1710

Merged
merged 15 commits into from
May 3, 2024
Merged

Conversation

louisfd
Copy link
Member

@louisfd louisfd commented Apr 29, 2024

Use breadth first search algorithm instead of pure recursivity in autodiff memory management, because nodes could be visited recursively way too many times in some settings.
Fix #1702

Comment on lines 169 to 171
if !visited.contains(&parent) {
to_visit.push((parent, next_mode.clone()));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, we can still visit multiple times the same parent with different modes, but once we visited the parent once, we can't register new modes. Wouldn't it make sense to register the modes as well with priority (If TagAsUseful > Explore)?

The visited could be an HashMap containing the mode used. You can't register a parent with Explore when the parent was already visited, but you can with the mode TagAsUseful if the parent was visited with the mode Explore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike the NodeMemoryStatus for which there is one status per node (at first Unknown, then maybe Unavailable after the first propagation, then maybe Useful after this propagation), there is not one Mode per node.

The Mode made more sense in the previous form of the algorithm: we started in exploration until we found a node to tag as useful, then for this node to be usable we had to tag all of its parents as useful as well, so the algorithm switched to TagAsUseful for all of this branch. But with the visited approach we have to remember to go back to this mode when we get to a parent of a useful node. So the mode becomes tied to nodes in the vec and it's not very elegant as it was supposed to be a status of the algorithm, not the node. I think I can remove the concept of Mode altogether, it will be clearer.

Copy link

codecov bot commented Apr 30, 2024

Codecov Report

Attention: Patch coverage is 99.23077% with 1 lines in your changes are missing coverage. Please review.

Project coverage is 86.44%. Comparing base (5d959e2) to head (fd772ac).

Files Patch % Lines
...tes/burn-autodiff/src/runtime/memory_management.rs 98.82% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1710      +/-   ##
==========================================
+ Coverage   86.42%   86.44%   +0.01%     
==========================================
  Files         697      697              
  Lines       82645    82729      +84     
==========================================
+ Hits        71429    71513      +84     
  Misses      11216    11216              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines 149 to 151
if !visited_as_useful.contains(&parent)
&& (Some(&NodeMemoryStatus::Useful) == self.statuses.get(&node_id)
|| !visited_as_unknown.contains(&parent))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would refactor that a bit, very hard to read:

if visited_as_useful.contains(&parents); {
    continue;
}

let is_useful = Some(&NodeMemoryStatus::Useful) == self.statuses.get(&node_id);

if is_useful || !visited_as_unknown.contains(&parents) {
   to_visit.push((parent, Some(node_id)))
}

And I actually think there is still a performance problem and even a bug. I think it should be:

if visited_as_useful.contains(&parents); {
    continue;
}

let is_useful = Some(&NodeMemoryStatus::Useful) == self.statuses.get(&node_id);

if is_useful {
   to_visit.push((parent, Some(node_id)))
}

Since the vector to_visit is already filled with all nodes, you don't need to keep track of nodes that were visited as unknown.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote the algorithm, it's more efficient and elegant

let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
for parent in parents {
self.identify_leaves_and_deletables(parent, new_leaves, to_delete)
let mut visited = HashSet::new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be initialized with the right capacity, which is new_leaves.len()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm new_leaves is dynamically inserted during this algorithm, it's length 0 at that moment.
Also we can't really know visited length in advance because it depends if the nodes the algorithm sees are useful.

@nathanielsimard nathanielsimard merged commit a8661a2 into main May 3, 2024
15 checks passed
@nathanielsimard nathanielsimard deleted the fix/autodiff_mm/revisit_nodes branch May 3, 2024 13:45
nathanielsimard pushed a commit that referenced this pull request May 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

loss.backward() hangs after burn update 0.12 -> 0.13
3 participants