Skip to content

Commit

Permalink
add binary cross entropy operation
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Dec 24, 2023
1 parent 2cd617f commit 5ce3f80
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
15 changes: 15 additions & 0 deletions interactive-computational-graph/src/components/GraphContainer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import type AddNodeData from "../features/AddNodeData";
import {
ADD_DFDX_CODE,
ADD_F_CODE,
BINARY_CROSS_ENTROPY_DFDX_CODE,
BINARY_CROSS_ENTROPY_F_CODE,
COS_DFDX_CODE,
COS_F_CODE,
DIVIDE_DFDX_CODE,
Expand Down Expand Up @@ -275,6 +277,19 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
inputPorts: [new Port("y_t", false), new Port("y_e", false)],
helpText: "Calculate squared error $ (y_t - y_e)^2 $",
},
{
id: "binary_cross_entropy",
text: "Binary Cross-Entropy",
type: "SIMPLE",
namePrefix: "b",
operation: new Operation(
BINARY_CROSS_ENTROPY_F_CODE,
BINARY_CROSS_ENTROPY_DFDX_CODE,
),
inputPorts: [new Port("y_t", false), new Port("y_e", false)],
helpText:
"Calculate binary cross-entropy $ y_t * \\log(y_e) + (1 - y_t) * \\log(1 - y_e) $",
},
]);
const [nextNodeId, setNextNodeId] = useState<number>(0);
const [nodeNameBuilder] = useState<NodeNameBuilder>(new NodeNameBuilder());
Expand Down
102 changes: 102 additions & 0 deletions interactive-computational-graph/src/features/BuiltInCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1473,9 +1473,111 @@ function dfdx(fInputPortToNodes, fInputNodeToValues, xId) {
}
`;

const BINARY_CROSS_ENTROPY_F_CODE = `\
/**
* Calculates f().
* @param {Record<string, string[]>} fInputPortToNodes An object where the keys
* are port IDs and the values are node IDs of the connected input nodes.
* Example data for squared error:
* \`\`\`json
* {
* "y_t": ["0"],
* "y_e": ["1"]
* }
* \`\`\`
* @param {Record<string, string>} fInputNodeToValues An object where the keys
* are node IDs and the values are node values of the connected input nodes.
* Example data for squared error:
* \`\`\`json
* {
* "0": "0",
* "1": "0.5"
* }
* \`\`\`
* @returns {string} Evaluated f value. For example: if we consider
* the above example data, then the value is "-0.693" because
* f(y_t, y_e) = y_t * log(y_e) + (1 - y_t) * log(1 - y_e) =
* 0 * log(0.5) + (1 - 0) * log(1 - 0.5) ~= -0.693.
*/
function f(fInputPortToNodes, fInputNodeToValues) {
if (fInputPortToNodes.y_t.length !== 1) {
throw new Error("Should have exactly 1 input node for port y_t");
}
if (fInputPortToNodes.y_e.length !== 1) {
throw new Error("Should have exactly 1 input node for port y_e");
}
const yTrueInputNodeId = fInputPortToNodes.y_t[0];
const yEstimateInputNodeId = fInputPortToNodes.y_e[0];
const yTrue = parseFloat(fInputNodeToValues[yTrueInputNodeId]);
const yEstimate = parseFloat(fInputNodeToValues[yEstimateInputNodeId]);
const y = yTrue * Math.log(yEstimate) + (1 - yTrue) * Math.log(1 - yEstimate);
return \`\${y}\`;
}
`;

const BINARY_CROSS_ENTROPY_DFDX_CODE = `\
/**
* Calculates df/dx.
* @param {Record<string, string[]>} fInputPortToNodes An object where the keys
* are port IDs and the values are node IDs of the connected input nodes.
* Example data for squared error:
* \`\`\`json
* {
* "y_t": ["0"],
* "y_e": ["1"]
* }
* \`\`\`
* @param {Record<string, string>} fInputNodeToValues An object where the keys
* are node IDs and the values are node values of the connected input nodes.
* Example data for squared error:
* \`\`\`json
* {
* "0": "0",
* "1": "0.5"
* }
* \`\`\`
* @param {string} xId Node ID of x. Note that the framework will not call this
* function for the following cases:
* - x is a constant node (i.e., x will always be a variable)
* - x is the node of f (i.e., the derivative is always 1)
* - x is not on the forward/reverse differentiation path (i.e., gradient of x
* doesn't flow through f node)
* @returns {string} Evaluated derivative df/dy. For example, if we consider
* the above example data and assume xId is "0", then the value is "-1.386"
* since f(y_t, y_e) = y_t * log(y_e) + (1 - y_t) * log(1 - y_e) and
* df/d(y_t) = log(y_e) - log(1 - y_e) = log(0.5) - log(1 - 0.5) ~= -1.386.
*/
function dfdx(fInputPortToNodes, fInputNodeToValues, xId) {
if (fInputPortToNodes.y_t.length !== 1) {
throw new Error("Should have exactly 1 input node for port y_t");
}
if (fInputPortToNodes.y_e.length !== 1) {
throw new Error("Should have exactly 1 input node for port y_e");
}
const hasXInYTrue = fInputPortToNodes.y_t.includes(xId);
const hasXInYEstimate = fInputPortToNodes.y_e.includes(xId);
if (!hasXInYTrue && !hasXInYEstimate) {
return "0";
}
const yTrueInputNodeId = fInputPortToNodes.y_t[0];
const yEstimateInputNodeId = fInputPortToNodes.y_e[0];
const yTrue = parseFloat(fInputNodeToValues[yTrueInputNodeId]);
const yEstimate = parseFloat(fInputNodeToValues[yEstimateInputNodeId]);
let df = 0;
if (hasXInYTrue) {
df = Math.log(yEstimate) - Math.log(1 - yEstimate);
} else {
df = (yTrue - yEstimate) / (yEstimate - Math.pow(yEstimate, 2));
}
return \`\${df}\`;
}
`;

export {
ADD_DFDX_CODE,
ADD_F_CODE,
BINARY_CROSS_ENTROPY_DFDX_CODE,
BINARY_CROSS_ENTROPY_F_CODE,
COS_DFDX_CODE,
COS_F_CODE,
DIVIDE_DFDX_CODE,
Expand Down

0 comments on commit 5ce3f80

Please sign in to comment.