From c19132054f27865a11f9ae2534ad53985f2df02e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 26 Nov 2024 12:05:26 +0000 Subject: [PATCH 1/3] test: add test showing simple replace breaking on nested --- hugr-core/src/builder.rs | 2 +- hugr-core/src/hugr/rewrite/simple_replace.rs | 48 +++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index f40703243..8bde38cc6 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -256,7 +256,7 @@ pub(crate) mod test { pub(super) const QB: Type = crate::extension::prelude::QB_T; /// Wire up inputs of a Dataflow container to the outputs. - pub(super) fn n_identity( + pub(crate) fn n_identity( dataflow_builder: T, ) -> Result { let w = dataflow_builder.input_wires(); diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 3018adce0..0ce9e422a 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -221,11 +221,12 @@ pub(in crate::hugr::rewrite) mod test { use rstest::{fixture, rstest}; use std::collections::{HashMap, HashSet}; + use crate::builder::test::n_identity; use crate::builder::{ endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::BOOL_T; + use crate::extension::prelude::{BOOL_T, QB_T}; use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; @@ -774,6 +775,51 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.node_count(), 4); } + #[rstest] + fn test_nested_replace(dfg_hugr2: Hugr) { + // replace a node with a hugr with children + + let mut h = dfg_hugr2; + let h_node = h + .nodes() + .find(|node: &Node| *h.get_optype(*node) == h_gate().into()) + .unwrap(); + + // build a nested identity hugr + let mut nest_build = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap(); + let [input] = nest_build.input_wires_arr(); + let inner_build = nest_build.dfg_builder_endo([(QB_T, input)]).unwrap(); + let inner_dfg = n_identity(inner_build).unwrap(); + let inner_dfg_node = inner_dfg.node(); + let replacement = nest_build + .finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)]) + .unwrap(); + let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap(); + let nu_inp = vec![( + (inner_dfg_node, IncomingPort::from(0)), + (h_node, IncomingPort::from(0)), + )] + .into_iter() + .collect(); + + let nu_out = vec![( + (h.get_io(h.root()).unwrap()[1], IncomingPort::from(0)), + IncomingPort::from(0), + )] + .into_iter() + .collect(); + + let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); + + assert_eq!(h.node_count(), 4); + + rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); + assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); + + println!("{}", h.mermaid_string()); + assert_eq!(h.node_count(), 6); + } + use crate::hugr::rewrite::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; From 28e65f68108501481fee962083755b360d7bedbd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 26 Nov 2024 13:43:13 +0000 Subject: [PATCH 2/3] attempt to fix using descendents --- hugr-core/src/hugr/rewrite/simple_replace.rs | 49 ++++++++++++++------ 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 0ce9e422a..24b382986 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -2,12 +2,15 @@ use std::collections::{HashMap, HashSet}; -use crate::hugr::views::SiblingSubgraph; +use crate::hugr::hugrmut::InsertionResult; +use crate::hugr::views::{DescendantsGraph, HierarchyView, SiblingSubgraph}; use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; use thiserror::Error; +use super::inline_dfg::{InlineDFG, InlineDFGError}; + /// Specification of a simple replacement operation. #[derive(Debug, Clone)] pub struct SimpleReplacement { @@ -74,19 +77,23 @@ impl Rewrite for SimpleReplacement { return Err(SimpleReplacementError::InvalidRemovedNode()); } } - // 3. Do the replacement. - // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. - // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). + // // 3. Do the replacement. + // // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. + // // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). let mut index_map: HashMap = HashMap::new(); - let replacement_nodes = self - .replacement - .children(self.replacement.root()) + let replace_io = self.replacement.get_io(self.replacement.root()).unwrap(); + let replace_ignore_nodes = [replace_io[0], replace_io[1], self.replacement.root()]; + let descendants: DescendantsGraph = + DescendantsGraph::try_new(&self.replacement, self.replacement.root()) + .expect("parent already checked."); + let replacement_inner_nodes = descendants + .nodes() + .filter(|n| !replace_ignore_nodes.contains(n)) .collect::>(); // slice of nodes omitting Input and Output: - let replacement_inner_nodes = &replacement_nodes[2..]; + // let replacement_inner_nodes = &replacement_nodes[2..]; let self_output_node = h.children(parent).nth(1).unwrap(); - let replacement_output_node = *replacement_nodes.get(1).unwrap(); - for &node in replacement_inner_nodes { + for &node in replacement_inner_nodes.iter() { // Add the nodes. let op: &OpType = self.replacement.get_optype(node); let new_node = h.add_node_after(self_output_node, op.clone()); @@ -97,7 +104,7 @@ impl Rewrite for SimpleReplacement { h.overwrite_node_metadata(new_node, meta); } // Add edges between all newly added nodes matching those in replacement. - for &node in replacement_inner_nodes { + for &node in replacement_inner_nodes.iter() { let new_node = index_map.get(&node).unwrap(); for outport in self.replacement.node_outputs(node) { for target in self.replacement.linked_inputs(node, outport) { @@ -109,6 +116,14 @@ impl Rewrite for SimpleReplacement { } } + // let InsertionResult { + // new_root, + // node_map: index_map, + // } = h.insert_hugr(parent, self.replacement.clone()); + // println!("{}", h.mermaid_string()); + // dbg!(new_root); + // h.apply_rewrite(InlineDFG(new_root.into()))?; + // Now we proceed to connect the edges between the newly inserted // replacement and the rest of the graph. // @@ -136,6 +151,10 @@ impl Rewrite for SimpleReplacement { )); } } + let replacement_output_node = self + .replacement + .get_io(self.replacement.root()) + .expect("parent already checked.")[1]; // 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an // edge from (the new copy of) the predecessor of q to p. for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { @@ -213,6 +232,9 @@ pub enum SimpleReplacementError { /// Node in replacement graph is invalid. #[error("A node in the replacement graph is invalid.")] InvalidReplacementNode(), + /// Inlining replacement failed. + #[error("Inlining replacement failed: {0}")] + InliningFailed(#[from] InlineDFGError), } #[cfg(test)] @@ -814,9 +836,10 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); - assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); - println!("{}", h.mermaid_string()); + h.update_validate(&PRELUDE_REGISTRY) + .unwrap_or_else(|e| panic!("{e}")); + assert_eq!(h.node_count(), 6); } From 4d1427e3db9e910d401da8cb6a3621f456446fe4 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 26 Nov 2024 14:28:44 +0000 Subject: [PATCH 3/3] fix: hierarchical simple replacement using insert_hugr Closes #1715 --- hugr-core/src/hugr/rewrite/simple_replace.rs | 74 ++++++-------------- 1 file changed, 23 insertions(+), 51 deletions(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 24b382986..026a996b8 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -3,13 +3,14 @@ use std::collections::{HashMap, HashSet}; use crate::hugr::hugrmut::InsertionResult; -use crate::hugr::views::{DescendantsGraph, HierarchyView, SiblingSubgraph}; -use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; +pub use crate::hugr::internal::HugrMutInternals; +use crate::hugr::views::SiblingSubgraph; +use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; use thiserror::Error; -use super::inline_dfg::{InlineDFG, InlineDFGError}; +use super::inline_dfg::InlineDFGError; /// Specification of a simple replacement operation. #[derive(Debug, Clone)] @@ -65,7 +66,7 @@ impl Rewrite for SimpleReplacement { unimplemented!() } - fn apply(mut self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { let parent = self.subgraph.get_parent(h); // 1. Check the parent node exists and is a DataflowParent. if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { @@ -77,52 +78,24 @@ impl Rewrite for SimpleReplacement { return Err(SimpleReplacementError::InvalidRemovedNode()); } } - // // 3. Do the replacement. - // // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. - // // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). - let mut index_map: HashMap = HashMap::new(); - let replace_io = self.replacement.get_io(self.replacement.root()).unwrap(); - let replace_ignore_nodes = [replace_io[0], replace_io[1], self.replacement.root()]; - let descendants: DescendantsGraph = - DescendantsGraph::try_new(&self.replacement, self.replacement.root()) - .expect("parent already checked."); - let replacement_inner_nodes = descendants - .nodes() - .filter(|n| !replace_ignore_nodes.contains(n)) - .collect::>(); - // slice of nodes omitting Input and Output: - // let replacement_inner_nodes = &replacement_nodes[2..]; - let self_output_node = h.children(parent).nth(1).unwrap(); - for &node in replacement_inner_nodes.iter() { - // Add the nodes. - let op: &OpType = self.replacement.get_optype(node); - let new_node = h.add_node_after(self_output_node, op.clone()); - index_map.insert(node, new_node); - - // Move the metadata - let meta: Option = self.replacement.take_node_metadata(node); - h.overwrite_node_metadata(new_node, meta); + // 3. Do the replacement. + // 3.1. Insert the replacement as a whole. + let InsertionResult { + new_root, + node_map: index_map, + } = h.insert_hugr(parent, self.replacement.clone()); + + // remove the Input and Output nodes from the replacement graph + let replace_children = h.children(new_root).collect::>(); + for &io in &replace_children[..2] { + h.remove_node(io); } - // Add edges between all newly added nodes matching those in replacement. - for &node in replacement_inner_nodes.iter() { - let new_node = index_map.get(&node).unwrap(); - for outport in self.replacement.node_outputs(node) { - for target in self.replacement.linked_inputs(node, outport) { - if self.replacement.get_optype(target.0).tag() != OpTag::Output { - let new_target = index_map.get(&target.0).unwrap(); - h.connect(*new_node, outport, *new_target, target.1); - } - } - } + // make all replacement top level children children of the parent + for &child in &replace_children[2..] { + h.set_parent(child, parent); } - - // let InsertionResult { - // new_root, - // node_map: index_map, - // } = h.insert_hugr(parent, self.replacement.clone()); - // println!("{}", h.mermaid_string()); - // dbg!(new_root); - // h.apply_rewrite(InlineDFG(new_root.into()))?; + // remove the replacement root (which now has no children and no edges) + h.remove_node(new_root); // Now we proceed to connect the edges between the newly inserted // replacement and the rest of the graph. @@ -807,7 +780,7 @@ pub(in crate::hugr::rewrite) mod test { .find(|node: &Node| *h.get_optype(*node) == h_gate().into()) .unwrap(); - // build a nested identity hugr + // build a nested identity dfg let mut nest_build = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap(); let [input] = nest_build.input_wires_arr(); let inner_build = nest_build.dfg_builder_endo([(QB_T, input)]).unwrap(); @@ -825,7 +798,7 @@ pub(in crate::hugr::rewrite) mod test { .collect(); let nu_out = vec![( - (h.get_io(h.root()).unwrap()[1], IncomingPort::from(0)), + (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)), IncomingPort::from(0), )] .into_iter() @@ -836,7 +809,6 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); - println!("{}", h.mermaid_string()); h.update_validate(&PRELUDE_REGISTRY) .unwrap_or_else(|e| panic!("{e}"));