diff --git a/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx b/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx new file mode 100644 index 00000000000..d9644b84e59 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/prelu/prelu.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/prelu/prelu.py b/crates/burn-import/onnx-tests/tests/prelu/prelu.py new file mode 100644 index 00000000000..f37030ad909 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/prelu/prelu.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# used to generate model: prelu.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.relu1 = nn.PReLU() + + def forward(self, x): + x = self.relu1(x) + return x + + +def main(): + + # Set seed for reproducibility + torch.manual_seed(42) + + torch.set_printoptions(precision=8) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + file_name = "prelu.onnx" + test_input = torch.randn(2, 3, device=device) + torch.onnx.export(model, test_input, file_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data of ones: {}".format(test_input)) + print("Test input data shape of ones: {}".format(test_input.shape)) + output = model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + + print("Test output: {}".format(output)) + + +if __name__ == '__main__': + main() + diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 89b175de588..7505c8886cf 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -1,5 +1,6 @@ use super::layer_norm::LayerNormNode; use super::mask_where::WhereNode; +use super::prelu::PReluNode; use super::unsqueeze::UnsqueezeNode; use super::{ avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, @@ -85,6 +86,7 @@ pub enum Node { Conv1d(Conv1dNode), Conv2d(Conv2dNode), ConvTranspose2d(ConvTranspose2dNode), + PRelu(PReluNode), Dropout(DropoutNode), Gather(GatherNode), GlobalAvgPool(GlobalAvgPoolNode), @@ -111,6 +113,7 @@ macro_rules! match_all { Node::Conv1d(node) => $func(node), Node::Conv2d(node) => $func(node), Node::ConvTranspose2d(node) => $func(node), + Node::PRelu(node) => $func(node), Node::Dropout(node) => $func(node), Node::Gather(node) => $func(node), Node::GlobalAvgPool(node) => $func(node), @@ -147,6 +150,7 @@ impl Node { Node::Conv1d(_) => "conv1d", Node::Conv2d(_) => "conv2d", Node::ConvTranspose2d(_) => "conv_transpose2d", + Node::PRelu(_) => "prelu", Node::Dropout(_) => "dropout", Node::Gather(_) => "gather", Node::GlobalAvgPool(_) => "global_avg_pool", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 965652c7a2b..ae936bbadb8 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -17,6 +17,7 @@ pub(crate) mod linear; pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool2d; +pub(crate) mod prelu; pub(crate) mod reshape; pub(crate) mod unary; pub(crate) mod unsqueeze; diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs new file mode 100644 index 00000000000..7eb84465898 --- /dev/null +++ b/crates/burn-import/src/burn/node/prelu.rs @@ -0,0 +1,100 @@ +use super::{Node, NodeCodegen, SerializationBackend}; +use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type}; +use burn::{ + module::{Param, ParamId}, + nn::{PReluConfig, PReluRecord}, + record::{PrecisionSettings, Record}, + tensor::{DataSerialize, Tensor}, +}; +use proc_macro2::TokenStream; +use quote::quote; +use serde::Serialize; + +#[derive(Clone, Debug)] +pub struct PReluNode { + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub alpha: DataSerialize, + pub config: PReluConfig, +} + +impl PReluNode { + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + alpha: DataSerialize, + config: PReluConfig, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + PRelu + }, + ), + input, + output, + alpha, + config, + } + } +} + +impl NodeCodegen for PReluNode { + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self) -> Option { + let name = &self.field.name; + + let num_parameters = self.config.num_parameters.to_tokens(); + let alpha = self.config.alpha.to_tokens(); + let tokens = quote! { + let #name = PReluConfig::new(#num_parameters, #alpha) + .init(device); + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + let device = Default::default(); + let record = PReluRecord:: { + alpha: Param::initialized( + ParamId::new(), + Tensor::from_data(self.alpha.clone().convert(), &device), + ), + }; + + let item = Record::into_item::(record); + item.serialize(serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let field = &self.field.name; + + quote! { + let #output = self.#field.forward(#input); + } + } + fn register_imports(&self, imports: &mut BurnImports) { + imports.register("burn::nn::PRelu"); + imports.register("burn::nn::prelu::PRelu"); + imports.register("burn::nn::prelu::PReluConfig"); + } + + fn into_node(self) -> Node { + Node::PRelu(self) + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index da213d55e7d..e0c5cd16e2c 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -1,7 +1,7 @@ use burn::nn::{ conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig}, pool::{AvgPool2dConfig, MaxPool2dConfig}, - BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d, + BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PReluConfig, PaddingConfig1d, PaddingConfig2d, }; @@ -120,6 +120,21 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { .with_padding(padding) .with_dilation([dilations[0] as usize, dilations[1] as usize]) } +pub fn prelu_config(curr: &Node) -> PReluConfig { + let mut alpha = 0.01; + let mut num_parameters = 0; + for (key, value) in curr.attrs.iter() { + match key.as_str() { + "alpha" => alpha = value.clone().into_f32(), + "num_parameters" => num_parameters = value.clone().into_i32(), + _ => {} + } + } + + PReluConfig::new() + .with_num_parameters(num_parameters as usize) + .with_alpha(alpha as f64) +} pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut attrs = curr.attrs.clone(); diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 51ebf8683b6..ba7d8fbe154 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -30,6 +30,7 @@ use crate::{ mask_where::WhereNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, + prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -236,6 +237,7 @@ impl OnnxGraph { NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), + NodeType::PRelu => graph.register(Self::prelu_conversion::(node)), NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), NodeType::MatMul => graph.register(Self::matmul_conversion(node)), NodeType::Neg => graph.register(Self::neg_conversion(node)), @@ -695,6 +697,14 @@ impl OnnxGraph { MaxPool2dNode::new(name, input, output, config) } + fn prelu_conversion(node: Node) -> PReluNode { + let input = node.inputs.first().unwrap().to_tensor_type(); + let output = node.outputs.first().unwrap().to_tensor_type(); + let weight = extract_data_serialize::(1, &node).unwrap(); + let config = prelu_config(&node); + let name = &node.name; + PReluNode::::new(name, input, output, weight, config) + } fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode { let input = node.inputs.first().unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type(); diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000000..982ecf0496f --- /dev/null +++ b/flake.lock @@ -0,0 +1,153 @@ +{ + "nodes": { + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1712014858, + "narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-utils": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixgl": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1713543440, + "narHash": "sha256-lnzZQYG0+EXl/6NkGpyIz+FEOc/DSEG57AP1VsdeNrM=", + "owner": "nix-community", + "repo": "nixGL", + "rev": "310f8e49a149e4c9ea52f1adf70cdc768ec53f8a", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixGL", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1660551188, + "narHash": "sha256-a1LARMMYQ8DPx1BgoI/UN4bXe12hhZkCNqdxNi6uS0g=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "441dc5d512153039f19ef198e662e4f3dbb9fd65", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1714253743, + "narHash": "sha256-mdTQw2XlariysyScCv2tTE45QSU9v/ezLcHJ22f0Nxc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "58a1abdbae3217ca6b702f03d3b35125d88a2994", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-parts": "flake-parts", + "nixgl": "nixgl", + "nixpkgs": "nixpkgs_2", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1714529851, + "narHash": "sha256-YMKJW880f7LHXVRzu93xa6Ek+QLECIu0IRQbXbzZe38=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "9ca720fdcf7865385ae3b93ecdf65f1a64cb475e", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000000..3e1b485d161 --- /dev/null +++ b/flake.nix @@ -0,0 +1,65 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-parts = { + url = "github:hercules-ci/flake-parts"; + inputs.nixpkgs-lib.follows = "nixpkgs"; + }; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + nixgl = { + url = "github:nix-community/nixGL"; + # inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = {nixpkgs, ...} @ inputs: + inputs.flake-parts.lib.mkFlake {inherit inputs;} { + systems = ["x86_64-linux"]; + perSystem = { + config, + system, + lib, + ... + }: let + overlays = [ + (import inputs.rust-overlay) + ]; + pkgs = import nixpkgs { + inherit system overlays; + }; + cudaPkg = pkgs.cudaPackages.cudatoolkit.override {cudaVersion = "12.1";}; + in { + devShells.default = pkgs.mkShell rec { + packages = with pkgs; [ + pkg-config + openssl + glxinfo + vscode-extensions.llvm-org.lldb-vscode + taplo + mdbook + glib-networking + cudaPkg + inputs.nixgl.packages.${pkgs.system}.default + inputs.nixgl.packages.${pkgs.system}.nixVulkanNvidia + cudaPackages.cudnn + rust-bin.stable.latest.default + typos + libxkbcommon + libGL + wayland + vulkan-tools + vulkan-loader + # flamegraph + # samply + ]; + LD_LIBRARY_PATH = "${lib.makeLibraryPath packages}:/run/opengl-driver-32/lib:${pkgs.libGL}/lib:${cudaPkg}/lib:${pkgs.wayland}/lib"; + TORCH_CUDA_VERSION = "cu121"; + PATH = "~/.cargo/bin:$PATH"; + }; + formatter = pkgs.alejandra; + }; + }; +}