Skip to content

Commit

Permalink
use zero-based IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Dec 22, 2023
1 parent 065e96f commit 144efb1
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 92 deletions.
158 changes: 79 additions & 79 deletions interactive-computational-graph/src/components/GraphContainer.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ it("edges and add node itself should be removed after removing add node", () =>
fireEvent.click(sumItem);

// Connect from the constant nodes to the add node
connectEdge("1", "output", "3", "a");
connectEdge("2", "output", "3", "b");
connectEdge("0", "output", "2", "a");
connectEdge("1", "output", "2", "b");

// Remove the add node
removeEdge(["reactflow__edge-1output-3a", "reactflow__edge-2output-3b"]);
removeNode(["3"]);
removeEdge(["reactflow__edge-0output-2a", "reactflow__edge-1output-2b"]);
removeNode(["2"]);

expect(screen.getByText("c_1")).toBeInTheDocument();
expect(screen.getByText("c_1")).toBeInTheDocument();
Expand All @@ -107,12 +107,12 @@ it("edges and sum node itself should be removed after removing sum node", () =>
fireEvent.click(sumItem);

// Connect from the constant nodes to the sum node
connectEdge("1", "output", "3", "x_i");
connectEdge("2", "output", "3", "x_i");
connectEdge("0", "output", "2", "x_i");
connectEdge("1", "output", "2", "x_i");

// Remove the sum node
removeEdge(["reactflow__edge-1output-3x_i", "reactflow__edge-2output-3x_i"]);
removeNode(["3"]);
removeEdge(["reactflow__edge-0output-2x_i", "reactflow__edge-1output-2x_i"]);
removeNode(["2"]);

expect(screen.getByText("c_1")).toBeInTheDocument();
expect(screen.getByText("c_2")).toBeInTheDocument();
Expand All @@ -137,12 +137,12 @@ it("can connect same node to multiple ports, then remove the connections", () =>
fireEvent.click(addItem);

// Connect from the variable node to the add node
connectEdge("1", "output", "2", "a");
connectEdge("1", "output", "2", "b");
connectEdge("0", "output", "1", "a");
connectEdge("0", "output", "1", "b");

// Disconnect from the variable node to the add node
removeEdge(["reactflow__edge-1output-2a"]);
removeEdge(["reactflow__edge-1output-2b"]);
removeEdge(["reactflow__edge-0output-1a"]);
removeEdge(["reactflow__edge-0output-1b"]);
});

it("input text fields should hide/show properly", () => {
Expand All @@ -157,22 +157,22 @@ it("input text fields should hide/show properly", () => {
const sumItem = screen.getByText("Sum");
fireEvent.click(sumItem);

expect(screen.getByTestId("input-item-3-x_i")).toBeInTheDocument();
expect(screen.getByTestId("input-item-2-x_i")).toBeInTheDocument();

// Connect from the 1st constant node to the sum node
connectEdge("1", "output", "3", "x_i");
connectEdge("0", "output", "2", "x_i");

expect(screen.queryByTestId("input-item-3-x_i")).toBeNull();
expect(screen.queryByTestId("input-item-2-x_i")).toBeNull();

// Connect from the 2nd constant node to the sum node
connectEdge("2", "output", "3", "x_i");
connectEdge("1", "output", "2", "x_i");

expect(screen.queryByTestId("input-item-3-x_i")).toBeNull();
expect(screen.queryByTestId("input-item-2-x_i")).toBeNull();

// Disconnect from the constant nodes to the sum node
removeEdge(["reactflow__edge-1output-3x_i", "reactflow__edge-2output-3x_i"]);
removeEdge(["reactflow__edge-0output-2x_i", "reactflow__edge-1output-2x_i"]);

expect(screen.getByTestId("input-item-3-x_i")).toBeInTheDocument();
expect(screen.getByTestId("input-item-2-x_i")).toBeInTheDocument();
});

it("should show error message when connecting the same edge twice", () => {
Expand All @@ -187,9 +187,9 @@ it("should show error message when connecting the same edge twice", () => {
fireEvent.click(sumItem);

// Connect from the 1st constant node to the add node port a twice
connectEdge("1", "output", "2", "x_i");
connectEdge("0", "output", "1", "x_i");
expect(screen.queryByRole("alert")).toBeNull();
connectEdge("1", "output", "2", "x_i");
connectEdge("0", "output", "1", "x_i");

const snackbar = screen.getByRole("alert");
expect(snackbar).toBeInTheDocument();
Expand All @@ -211,8 +211,8 @@ it("should show error message when connecting to the single-connection port", ()
fireEvent.click(sumItem);

// Connect from the 2nd constant node to the add node port a
connectEdge("1", "output", "3", "a");
connectEdge("2", "output", "3", "a");
connectEdge("0", "output", "2", "a");
connectEdge("1", "output", "2", "a");

const snackbar = screen.getByRole("alert");
expect(snackbar).toBeInTheDocument();
Expand All @@ -229,7 +229,7 @@ it("should show error message when causing a cycle", () => {
fireEvent.click(sumItem);

// Connect from the output of sum node to the input of sum node
connectEdge("1", "output", "1", "x_i");
connectEdge("0", "output", "0", "x_i");

const snackbar = screen.getByRole("alert");
expect(snackbar).toBeInTheDocument();
Expand All @@ -251,15 +251,15 @@ it("derivative target should reset when the target node is removed", () => {
fireEvent.click(sumItem);

// Connect from the constant nodes to the sum node
connectEdge("1", "output", "3", "x_i");
connectEdge("2", "output", "3", "x_i");
connectEdge("0", "output", "2", "x_i");
connectEdge("1", "output", "2", "x_i");

// Select the sum node as the derivative target
setDerivativeTarget("s_1");

// Remove the sum node
removeEdge(["reactflow__edge-1output-3x_i", "reactflow__edge-2output-3x_i"]);
removeNode(["3"]);
removeEdge(["reactflow__edge-0output-2x_i", "reactflow__edge-1output-2x_i"]);
removeNode(["2"]);

expect(getDerivativeTarget()).toBe("");
});
Expand All @@ -277,14 +277,14 @@ it("derivative target name should update when the node name is updated", () => {
fireEvent.click(sumItem);

// Connect from the constant nodes to the sum node
connectEdge("1", "output", "3", "x_i");
connectEdge("2", "output", "3", "x_i");
connectEdge("0", "output", "2", "x_i");
connectEdge("1", "output", "2", "x_i");

// Select the sum node as the derivative target
setDerivativeTarget("s_1");

// Update the node name of the sum node
setNodeName("3", "s_1");
setNodeName("2", "s_1");

expect(getDerivativeTarget()).toBe("s_1");
});
Expand All @@ -295,12 +295,12 @@ it("outputs should be set correctly after adding the nodes", () => {
// Add the nodes
const addItem = screen.getByText("Add");
const cosItem = screen.getByText("Cos");
fireEvent.click(addItem); // id=1
fireEvent.click(cosItem); // id=2
fireEvent.click(addItem); // id=0
fireEvent.click(cosItem); // id=1

// Check the output values
expect(getOutputItemValue("1", "VALUE")).toBe("0");
expect(getOutputItemValue("2", "VALUE")).toBe("1");
expect(getOutputItemValue("0", "VALUE")).toBe("0");
expect(getOutputItemValue("1", "VALUE")).toBe("1");
});

// It uses example from https://colah.github.io/posts/2015-08-Backprop/
Expand All @@ -311,121 +311,121 @@ it("outputs should change when derivative mode/target is changed", () => {
const variableItem = screen.getByText("Variable");
const addItem = screen.getByText("Add");
const multiplyItem = screen.getByText("Multiply");
fireEvent.click(variableItem); // id=0
fireEvent.click(variableItem); // id=1
fireEvent.click(variableItem); // id=2
fireEvent.click(addItem); // id=2
fireEvent.click(addItem); // id=3
fireEvent.click(addItem); // id=4
fireEvent.click(multiplyItem); // id=5
fireEvent.click(multiplyItem); // id=4

// Connect from the variable nodes to add nodes
connectEdge("0", "output", "2", "a");
connectEdge("1", "output", "2", "b");
connectEdge("1", "output", "3", "a");
connectEdge("2", "output", "3", "b");
connectEdge("2", "output", "4", "a");

// Set add node input values
setInputItemValue("1", "value", "2");
setInputItemValue("2", "value", "1");
setInputItemValue("4", "b", "1");
setInputItemValue("0", "value", "2");
setInputItemValue("1", "value", "1");
setInputItemValue("3", "b", "1");

// Connect from the add nodes to multiply node
connectEdge("3", "output", "5", "a");
connectEdge("4", "output", "5", "b");
connectEdge("2", "output", "4", "a");
connectEdge("3", "output", "4", "b");

// Select the multiply node as the derivative target
setDerivativeTarget("m_1");

// Check the output values
expect(getOutputItemValue("3", "VALUE")).toBe("3");
expect(getOutputItemValue("4", "VALUE")).toBe("2");
expect(getOutputItemValue("5", "VALUE")).toBe("6");
expect(getOutputItemValue("2", "VALUE")).toBe("3");
expect(getOutputItemValue("3", "VALUE")).toBe("2");
expect(getOutputItemValue("4", "VALUE")).toBe("6");

// Check the derivative labels
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("0", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{v_1}}=",
);
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{v_2}}=",
);
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{a_1}}=",
);
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{a_2}}=",
);
expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{m_1}}=",
);

// Check the derivative values
expect(getOutputItemValue("1", "DERIVATIVE")).toBe("2");
expect(getOutputItemValue("2", "DERIVATIVE")).toBe("5");
expect(getOutputItemValue("3", "DERIVATIVE")).toBe("2");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("3");
expect(getOutputItemValue("5", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("0", "DERIVATIVE")).toBe("2");
expect(getOutputItemValue("1", "DERIVATIVE")).toBe("5");
expect(getOutputItemValue("2", "DERIVATIVE")).toBe("2");
expect(getOutputItemValue("3", "DERIVATIVE")).toBe("3");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("1");

// Change the differentiation mode to forward mode
toggleDifferentiationMode();

// Check the output values
expect(getOutputItemValue("3", "VALUE")).toBe("3");
expect(getOutputItemValue("4", "VALUE")).toBe("2");
expect(getOutputItemValue("5", "VALUE")).toBe("6");
expect(getOutputItemValue("2", "VALUE")).toBe("3");
expect(getOutputItemValue("3", "VALUE")).toBe("2");
expect(getOutputItemValue("4", "VALUE")).toBe("6");

// Check the derivative labels
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("0", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{v_1}}{\\partial{m_1}}=",
);
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{v_2}}{\\partial{m_1}}=",
);
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{a_1}}{\\partial{m_1}}=",
);
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{a_2}}{\\partial{m_1}}=",
);
expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{m_1}}=",
);

// Check the derivative values
expect(getOutputItemValue("0", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("1", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("2", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("3", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("5", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("1");

// Select the second variable node as the derivative target
setDerivativeTarget("v_2");

// Check the output values
expect(getOutputItemValue("3", "VALUE")).toBe("3");
expect(getOutputItemValue("4", "VALUE")).toBe("2");
expect(getOutputItemValue("5", "VALUE")).toBe("6");
expect(getOutputItemValue("2", "VALUE")).toBe("3");
expect(getOutputItemValue("3", "VALUE")).toBe("2");
expect(getOutputItemValue("4", "VALUE")).toBe("6");

// Check the derivative labels
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("0", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{v_1}}{\\partial{v_2}}=",
);
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{v_2}}{\\partial{v_2}}=",
);
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{a_1}}{\\partial{v_2}}=",
);
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{a_2}}{\\partial{v_2}}=",
);
expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe(
expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe(
"\\displaystyle \\frac{\\partial{m_1}}{\\partial{v_2}}=",
);

// Check the derivative values
expect(getOutputItemValue("1", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("0", "DERIVATIVE")).toBe("0");
expect(getOutputItemValue("1", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("2", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("3", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("1");
expect(getOutputItemValue("5", "DERIVATIVE")).toBe("5");
expect(getOutputItemValue("4", "DERIVATIVE")).toBe("5");
});

const renderGraphContainer = (): void => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
helpText: "Calculate squared error $ (y_t - y_e)^2 $",
},
]);
const [nextNodeId, setNextNodeId] = useState<number>(1);
const [nextNodeId, setNextNodeId] = useState<number>(0);
const [nodeNameBuilder] = useState<NodeNameBuilder>(new NodeNameBuilder());
const [nextOperationId, setNextOperationId] = useState<number>(1);
const [nextOperationId, setNextOperationId] = useState<number>(0);

// Feature panel states
const [explainDerivativeData, setExplainDerivativeData] = useState<
Expand Down
Loading

0 comments on commit 144efb1

Please sign in to comment.