diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 1c2dfd6ea..4a5ef714e 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -3,6 +3,7 @@ #include "export_model_arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" +#include "models/dlrm/dlrm.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" @@ -69,6 +70,8 @@ tl::expected return get_candle_uno_computation_graph(get_default_candle_uno_config()); } else if (model_name == "bert") { return get_bert_computation_graph(get_default_bert_config()); + } else if (model_name == "dlrm") { + return get_dlrm_computation_graph(get_default_dlrm_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -145,6 +148,7 @@ int main(int argc, char **argv) { "inception_v3", "candle_uno", "bert", + "dlrm", "split_test", "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index 564cffaeb..2714c5b2e 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" +#include "models/dlrm/dlrm.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" @@ -393,5 +394,13 @@ TEST_SUITE(FF_TEST_SUITE) { std::string result = render_preprocessed_computation_graph_for_sp_decomposition(cg); } + + SUBCASE("dlrm") { + ComputationGraph cg = + get_dlrm_computation_graph(get_default_dlrm_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } } } diff --git a/lib/models/src/models/dlrm/dlrm.cc b/lib/models/src/models/dlrm/dlrm.cc index 6b699993f..93846dbec 100644 --- a/lib/models/src/models/dlrm/dlrm.cc +++ b/lib/models/src/models/dlrm/dlrm.cc @@ -128,7 +128,9 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) { DataType::INT64)); tensor_guid_t dense_input = create_input_tensor( - {config.batch_size, config.mlp_bot.front()}, DataType::FLOAT); + {config.batch_size, config.mlp_bot.front()}, + DataType::HALF); // TODO: change this to DataType::FLOAT after cgb.cast is + // implemented. // Construct the model tensor_guid_t bottom_mlp_output = create_dlrm_mlp(