Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore!: Add associated type Node to HugrView #1932

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions hugr-core/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ pub trait NodeIndex {
fn index(self) -> usize;
}

/// A trait for nodes in the Hugr.
pub trait HugrNode: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash {}

impl<T: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash> HugrNode for T {}

/// A port in the incoming direction.
#[derive(
Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
Expand All @@ -73,7 +78,7 @@ pub type Direction = portgraph::Direction;
)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
pub struct Wire(Node, OutgoingPort);
pub struct Wire<N = Node>(N, OutgoingPort);

impl Node {
/// Returns the node as a portgraph `NodeIndex`.
Expand Down Expand Up @@ -204,16 +209,16 @@ impl NodeIndex for Node {
}
}

impl Wire {
impl<N: HugrNode> Wire<N> {
/// Create a new wire from a node and a port.
#[inline]
pub fn new(node: Node, port: impl Into<OutgoingPort>) -> Self {
pub fn new(node: N, port: impl Into<OutgoingPort>) -> Self {
Self(node, port.into())
}

/// The node that this wire is connected to.
#[inline]
pub fn node(&self) -> Node {
pub fn node(&self) -> N {
self.0
}

Expand All @@ -224,9 +229,9 @@ impl Wire {
}
}

impl std::fmt::Display for Wire {
impl<N: HugrNode> std::fmt::Display for Wire<N> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Wire({}, {})", self.0.index(), self.1.index)
write!(f, "Wire({}, {})", self.0, self.1.index)
}
}

Expand All @@ -238,9 +243,9 @@ impl std::fmt::Display for Wire {
#[derive(
Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
pub enum CircuitUnit {
pub enum CircuitUnit<N = Node> {
/// Arbitrary input wire.
Wire(Wire),
Wire(Wire<N>),
/// Index to region input.
Linear(usize),
}
Expand Down
38 changes: 20 additions & 18 deletions hugr-core/src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use portgraph::view::{NodeFilter, NodeFiltered};
use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap};
use portgraph::{LinkMut, PortMut, PortView, SecondaryMap};

use crate::extension::ExtensionRegistry;
use crate::hugr::views::SiblingSubgraph;
Expand Down Expand Up @@ -64,7 +64,7 @@ pub trait HugrMut: HugrMutInternals {
}

/// Retrieve the complete metadata map for a node.
fn take_node_metadata(&mut self, node: Node) -> Option<NodeMetadataMap> {
fn take_node_metadata(&mut self, node: Self::Node) -> Option<NodeMetadataMap> {
if !self.valid_node(node) {
return None;
}
Expand Down Expand Up @@ -292,12 +292,14 @@ pub struct InsertionResult {
pub node_map: HashMap<Node, Node>,
}

fn translate_indices(node_map: HashMap<NodeIndex, NodeIndex>) -> HashMap<Node, Node> {
fn translate_indices(
node_map: HashMap<portgraph::NodeIndex, portgraph::NodeIndex>,
) -> HashMap<Node, Node> {
HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into())))
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMut for T {
fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
let node = self.as_mut().add_node(node.into());
self.as_mut()
Expand Down Expand Up @@ -406,14 +408,14 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
//
// No need to compute each node's extensions here, as we merge `other.extensions` directly.
for (&node, &new_node) in node_map.iter() {
let nodetype = other.get_optype(node.into());
let nodetype = other.get_optype(other.to_node(node));
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.base_hugr().metadata.get(node);
self.as_mut().metadata.set(new_node, meta.clone());
}
debug_assert_eq!(
Some(&new_root.pg_index()),
node_map.get(&other.root().pg_index())
node_map.get(&other.to_pg_index(other.root()))
);
InsertionResult {
new_root,
Expand All @@ -437,7 +439,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph);
// Update the optypes and metadata, copying them from the other graph.
for (&node, &new_node) in node_map.iter() {
let nodetype = other.get_optype(node.into());
let nodetype = other.get_optype(other.to_node(node));
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.base_hugr().metadata.get(node);
self.as_mut().metadata.set(new_node, meta.clone());
Expand All @@ -458,25 +460,25 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
///
/// This function does not update the optypes of the inserted nodes, so the
/// caller must do that.
fn insert_hugr_internal(
fn insert_hugr_internal<H: HugrView>(
hugr: &mut Hugr,
root: Node,
other: &impl HugrView,
) -> (Node, HashMap<NodeIndex, NodeIndex>) {
other: &H,
) -> (Node, HashMap<portgraph::NodeIndex, portgraph::NodeIndex>) {
let node_map = hugr
.graph
.insert_graph(&other.portgraph())
.unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}"));
let other_root = node_map[&other.root().pg_index()];
let other_root = node_map[&other.to_pg_index(other.root())];

// Update hierarchy and optypes
hugr.hierarchy
.push_child(other_root, root.pg_index())
.expect("Inserting a newly-created node into the hierarchy should never fail.");
for (&node, &new_node) in node_map.iter() {
other.children(node.into()).for_each(|child| {
other.children(other.to_node(node)).for_each(|child| {
hugr.hierarchy
.push_child(node_map[&child.pg_index()], new_node)
.push_child(node_map[&other.to_pg_index(child)], new_node)
.expect("Inserting a newly-created node into the hierarchy should never fail.");
});
}
Expand Down Expand Up @@ -504,7 +506,7 @@ fn insert_subgraph_internal(
root: Node,
other: &impl HugrView,
portgraph: &impl portgraph::LinkView,
) -> HashMap<NodeIndex, NodeIndex> {
) -> HashMap<portgraph::NodeIndex, portgraph::NodeIndex> {
let node_map = hugr
.graph
.insert_graph(&portgraph)
Expand All @@ -514,8 +516,8 @@ fn insert_subgraph_internal(
// update the hierarchy with their new id.
for (&node, &new_node) in node_map.iter() {
let new_parent = other
.get_parent(node.into())
.and_then(|parent| node_map.get(&parent.pg_index()).copied())
.get_parent(other.to_node(node))
.and_then(|parent| node_map.get(&other.to_pg_index(parent)).copied())
.unwrap_or(root.pg_index());
hugr.hierarchy
.push_child(new_node, new_parent)
Expand All @@ -527,7 +529,7 @@ fn insert_subgraph_internal(

/// Panic if [`HugrView::valid_node`] fails.
#[track_caller]
pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: Node) {
pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
if !hugr.valid_node(node) {
panic!(
"Received an invalid node {node} while mutating a HUGR:\n\n {}",
Expand All @@ -538,7 +540,7 @@ pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: Node) {

/// Panic if [`HugrView::valid_non_root`] fails.
#[track_caller]
pub(super) fn panic_invalid_non_root<H: HugrView + ?Sized>(hugr: &H, node: Node) {
pub(super) fn panic_invalid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
if !hugr.valid_non_root(node) {
panic!(
"Received an invalid non-root node {node} while mutating a HUGR:\n\n {}",
Expand Down
63 changes: 53 additions & 10 deletions hugr-core/src/hugr/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,23 @@ pub trait HugrInternals {
where
Self: 'p;

/// The type of nodes in the Hugr.
type Node: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash;

/// Returns a reference to the underlying portgraph.
fn portgraph(&self) -> Self::Portgraph<'_>;

/// Returns the Hugr at the base of a chain of views.
fn base_hugr(&self) -> &Hugr;

/// Return the root node of this view.
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;

/// Convert a node to a portgraph node index.
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;

/// Convert a portgraph node index to a node.
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}

impl HugrInternals for Hugr {
Expand All @@ -42,6 +51,8 @@ impl HugrInternals for Hugr {
where
Self: 'p;

type Node = Node;

#[inline]
fn portgraph(&self) -> Self::Portgraph<'_> {
&self.graph
Expand All @@ -53,21 +64,33 @@ impl HugrInternals for Hugr {
}

#[inline]
fn root_node(&self) -> Node {
fn root_node(&self) -> Self::Node {
self.root.into()
}

fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex {
node.pg_index()
}

fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node {
index.into()
}
}

impl<T: HugrInternals> HugrInternals for &T {
type Portgraph<'p>
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to (**self) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
Expand All @@ -77,11 +100,15 @@ impl<T: HugrInternals> HugrInternals for &mut T {
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to (**self) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
Expand All @@ -91,11 +118,15 @@ impl<T: HugrInternals> HugrInternals for Rc<T> {
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to (**self) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
Expand All @@ -105,11 +136,15 @@ impl<T: HugrInternals> HugrInternals for Arc<T> {
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to (**self) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
Expand All @@ -119,11 +154,15 @@ impl<T: HugrInternals> HugrInternals for Box<T> {
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to (**self) {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
Expand All @@ -133,19 +172,23 @@ impl<T: HugrInternals + ToOwned> HugrInternals for Cow<'_, T> {
= T::Portgraph<'p>
where
Self: 'p;
type Node = T::Node;

delegate! {
to self.as_ref() {
fn portgraph(&self) -> Self::Portgraph<'_>;
fn base_hugr(&self) -> &Hugr;
fn root_node(&self) -> Node;
fn root_node(&self) -> Self::Node;
fn to_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex;
fn to_node(&self, index: portgraph::NodeIndex) -> Self::Node;
}
}
}
/// Trait for accessing the mutable internals of a Hugr(Mut).
///
/// Specifically, this trait lets you apply arbitrary modifications that may
/// invalidate the HUGR.
pub trait HugrMutInternals: RootTagged {
pub trait HugrMutInternals: RootTagged<Node = Node> {
/// Returns the Hugr at the base of a chain of views.
fn hugr_mut(&mut self) -> &mut Hugr;

Expand Down Expand Up @@ -269,7 +312,7 @@ pub trait HugrMutInternals: RootTagged {
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMutInternals for T {
impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMutInternals for T {
fn hugr_mut(&mut self) -> &mut Hugr {
self.as_mut()
}
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub trait Rewrite {
/// Checks whether the rewrite would succeed on the specified Hugr.
/// If this call succeeds, [self.apply] should also succeed on the same `h`
/// If this calls fails, [self.apply] would fail with the same error.
fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>;
fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error>;

/// Mutate the specified Hugr, or fail with an error.
/// Returns [`Self::ApplyResult`] if successful.
Expand Down Expand Up @@ -58,7 +58,7 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
type ApplyResult = R::ApplyResult;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
self.underlying.verify(h)
}

Expand Down
Loading
Loading