Skip to content

Commit

Permalink
Add DLRM to export-model-arch and series_parallel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hsdfzhsdfz committed Sep 24, 2024
1 parent 274f1f0 commit 96d6f69
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
4 changes: 4 additions & 0 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/dlrm/dlrm.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
Expand Down Expand Up @@ -69,6 +70,8 @@ tl::expected<ComputationGraph, std::string>
return get_candle_uno_computation_graph(get_default_candle_uno_config());
} else if (model_name == "bert") {
return get_bert_computation_graph(get_default_bert_config());
} else if (model_name == "dlrm") {
return get_dlrm_computation_graph(get_default_dlrm_config());
} else if (model_name == "split_test") {
int batch_size = 8;
return get_split_test_computation_graph(batch_size);
Expand Down Expand Up @@ -145,6 +148,7 @@ int main(int argc, char **argv) {
"inception_v3",
"candle_uno",
"bert",
"dlrm",
"split_test",
"single_operator"};
CLIArgumentKey key_model_name = cli_add_positional_argument(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/dlrm/dlrm.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
Expand Down Expand Up @@ -393,5 +394,13 @@ TEST_SUITE(FF_TEST_SUITE) {
std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}

SUBCASE("dlrm") {
ComputationGraph cg =
get_dlrm_computation_graph(get_default_dlrm_config());

std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}
}
}
4 changes: 3 additions & 1 deletion lib/models/src/models/dlrm/dlrm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
DataType::INT64));

tensor_guid_t dense_input = create_input_tensor(
{config.batch_size, config.mlp_bot.front()}, DataType::FLOAT);
{config.batch_size, config.mlp_bot.front()},
DataType::HALF); // TODO: change this to DataType::FLOAT after cgb.cast is
// implemented.

// Construct the model
tensor_guid_t bottom_mlp_output = create_dlrm_mlp(
Expand Down

0 comments on commit 96d6f69

Please sign in to comment.