Skip to content

Commit

Permalink
Handle visibility in cube (#1929)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jun 26, 2024
1 parent d772a1c commit f9ec2e1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
26 changes: 19 additions & 7 deletions crates/burn-cube-macros/src/codegen_type/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ struct TypeCodegen {
name_expand: syn::Ident,
fields: Vec<syn::Field>,
generics: GenericsCodegen,
vis: syn::Visibility,
}

impl TypeCodegen {
pub fn expand_ty(&self) -> proc_macro2::TokenStream {
let mut fields = quote::quote! {};
Expand All @@ -19,17 +21,19 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

fields.extend(quote! {
#ident: <#ty as CubeType>::ExpandType,
#vis #ident: <#ty as CubeType>::ExpandType,
});
}

let generics = self.generics.type_definitions();
let vis = &self.vis;

quote! {
#[derive(Clone)]
struct #name #generics {
#vis struct #name #generics {
#fields
}
}
Expand All @@ -42,9 +46,10 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

fields.extend(quote! {
#ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
#vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
});
}

Expand All @@ -65,9 +70,10 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

args.extend(quote! {
#ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
#vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
});
fields.extend(quote! {
#ident,
Expand All @@ -76,11 +82,12 @@ impl TypeCodegen {

let generics_impl = self.generics.all_definitions();
let generics_use = self.generics.all_in_use();
let vis = &self.vis;

quote! {
impl #generics_impl #name #generics_use {
/// New kernel
pub fn new(#args) -> Self {
#vis fn new(#args) -> Self {
Self {
#fields
}
Expand Down Expand Up @@ -137,12 +144,13 @@ impl TypeCodegen {
for field in self.fields.iter() {
let ident = &field.ident;
let ty = &field.ty;
let vis = &field.vis;

body_input.extend(quote! {
#ident: <#ty as LaunchArg>::compile_input(builder, vectorization),
#vis #ident: <#ty as LaunchArg>::compile_input(builder, vectorization),
});
body_output.extend(quote! {
#ident: <#ty as LaunchArg>::compile_output(builder, vectorization),
#vis #ident: <#ty as LaunchArg>::compile_output(builder, vectorization),
});
}

Expand Down Expand Up @@ -194,9 +202,12 @@ impl TypeCodegen {
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
let name = ast.ident.clone();
let generics = ast.generics.clone();
let visibility = ast.vis.clone();

let name_string = name.to_string();
let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span());
let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span());

let mut fields = Vec::new();

match &ast.data {
Expand All @@ -215,6 +226,7 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> T
name_expand,
fields,
generics: GenericsCodegen::new(generics),
vis: visibility,
};

let expand_ty = codegen.expand_ty();
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-cube-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn codegen_cube(
func: &syn::ItemFn,
variable_tracker: &mut VariableTracker,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
let signature = expand_sig(&func.sig, variable_tracker);
let signature = expand_sig(&func.sig, &func.vis, variable_tracker);
let mut body = quote::quote! {};

for statement in func.block.stmts.iter() {
Expand Down Expand Up @@ -148,6 +148,7 @@ fn codegen_cube(

fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
variable_tracker: &mut VariableTracker,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();
Expand Down Expand Up @@ -188,6 +189,6 @@ fn expand_sig(

quote::quote! {
/// Expanded Cube function
pub fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
}
}

0 comments on commit f9ec2e1

Please sign in to comment.