Commit
feat: add topk operation

* fix test
* fix: update IOEntry::Node j in the input_name map
* fix: only run on macos
* chore: clean up
* chore: updated supported ops
* cleanup
* cleanup
* address feedback
* usize cast in config and add topk_smallest
* EOD
* EOD cleanup
* minor update
* run checks and other fixes
* Requested changes
* Remove extra top_k onnx file

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
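For context, a minimal numpy sketch (not part of this commit) of what the ONNX TopK operator computes: the k largest values along an axis together with their indices. The input matches the test script further below; the slicing assumes a 2-D input with axis=1, and the `order`/`values` names are illustrative only.

import numpy as np

# Illustrative only: ONNX TopK (largest, sorted) keeps the k largest entries
# along `axis` and also returns their positions.
X = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]], dtype=np.float32)
k, axis = 3, 1

order = np.argsort(-X, axis=axis)[:, :k]          # indices of the k largest entries per row
values = np.take_along_axis(X, order, axis=axis)  # the corresponding values

print(values)  # [[ 3.  2.  1.] [ 7.  6.  5.] [11. 10.  9.]]
print(order)   # [[3 2 1] [3 2 1] [3 2 1]]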
Showing 13 changed files with 319 additions and 7 deletions.
New file (+76 lines): Python script that generates the TopK ONNX test model.
import numpy as np
import onnx
from onnx import helper, TensorProto

# Define the input tensor
X = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]], dtype=np.float32)

# Define the value of K
k = 3
K = np.array([k], dtype=np.int64)
axis = 1
new_dims = [X.shape[0], k]

def create_model(op_set_version: int):
    input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)]

    output_tensors = [
        helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims),
        # TopK produces int64 indices per the ONNX spec
        helper.make_tensor_value_info('Indices', TensorProto.INT64, new_dims)
    ]

    # Create the TopK node
    if op_set_version > 1:
        # From opset 10 onward, K is passed as a second (int64) input tensor
        node = helper.make_node(
            'TopK',
            inputs=['X', 'K'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
        )
        input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT64, K.shape))
    else:
        # In opset 1, K is a node attribute
        node = helper.make_node(
            'TopK',
            inputs=['X'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
            k=k
        )

    # Create the graph
    graph = helper.make_graph(
        nodes=[node],
        name='TopKGraph',
        inputs=input_tensors,
        outputs=output_tensors,
        # Uncomment when initializers are supported. Currently we can't test
        # opset 10/11 since the code will require a k value to be initialized
        # for testing.
        # initializer=[
        #     helper.make_tensor('X', TensorProto.FLOAT, X.shape, X),
        #     helper.make_tensor('K', TensorProto.INT64, [1], [k]),
        # ]
    )

    # Create the model
    model = helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)]
    )
    # Check the model
    onnx.checker.check_model(model)

    # Save the model to a file
    onnx.save(model, f'top_k_opset_{op_set_version}.onnx')
    print(f"Model saved to top_k_opset_{op_set_version}.onnx")

def main():
    # Uncomment when initializers are supported.
    # for op_set_version in [1, 10, 11]:
    for op_set_version in [1]:
        create_model(op_set_version)

if __name__ == "__main__":
    main()
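One possible way to sanity-check the generated opset-1 model (not part of this commit) is to run it with onnxruntime and compare against a numpy reference; this assumes onnxruntime is installed and that top_k_opset_1.onnx was produced by the script above.

import numpy as np
import onnxruntime as ort

X = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]], dtype=np.float32)

# Run the generated model and compare its Values output with numpy.
session = ort.InferenceSession("top_k_opset_1.onnx")
values, indices = session.run(None, {"X": X})

expected_values = np.sort(X, axis=1)[:, ::-1][:, :3]  # 3 largest per row, descending
assert np.allclose(values, expected_values)
print(values)
print(indices)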
Binary file not shown.
New file (+112 lines): Burn TopK node codegen (Rust).
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};

/// Axis to search along and the number of elements (k) to keep.
#[derive(Config, Debug)]
pub struct TopKConfig {
    pub axis: usize,
    pub k: usize,
}

/// Codegen node for the ONNX TopK operation.
#[derive(Debug, Clone, new)]
pub struct TopKNode {
    pub input: TensorType,
    pub outputs: Vec<TensorType>,
    pub config: TopKConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode {
    fn output_types(&self) -> Vec<Type> {
        self.outputs
            .iter()
            .map(|t| Type::Tensor(t.clone()))
            .collect()
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let axis = self.config.axis.to_token_stream();
        let k = self.config.k.to_token_stream();

        let input = scope.tensor_use_owned(&self.input, node_position);
        let values_output = &self.outputs[0].name;
        let indices_output = &self.outputs[1].name;

        // Emit a call to Burn's topk_with_indices, which returns both the
        // top-k values and their indices.
        quote! {
            let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::TopK(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{test::assert_tokens, top_k::TopKNode},
        TensorType,
    };

    #[test]
    fn test_codegen_nodes() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TopKConfig::new(1, 3);

        graph.register(TopKNode::new(
            TensorType::new_float("input_tensor", 4),
            vec![
                TensorType::new_float("values_tensor", 4),
                TensorType::new_int("indices_tensor", 4),
            ],
            config,
        ));

        graph.register_input_output(
            vec!["input_tensor".to_string()],
            vec!["values_tensor".to_string(), "indices_tensor".to_string()],
        );

        let expected = quote! {
            use burn::tensor::Int;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
                    let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize);
                    (values_tensor, indices_tensor)
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}