diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml
index a5ac6fd29..f976d369d 100644
--- a/.github/workflows/per-lib-check.yml
+++ b/.github/workflows/per-lib-check.yml
@@ -23,6 +23,9 @@ jobs:
       - name: Add helpers directory to path
         run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH
 
+      - name: Free additional space on runner
+        run: free_space_on_runner.sh
+
       - name: Install nix
         uses: cachix/install-nix-action@v25
         with:
diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc
index 98b7a003c..1c2dfd6ea 100644
--- a/bin/export-model-arch/src/export_model_arch.cc
+++ b/bin/export-model-arch/src/export_model_arch.cc
@@ -1,6 +1,8 @@
 #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h"
 #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
 #include "export_model_arch/json_sp_model_export.dtg.h"
+#include "models/bert/bert.h"
+#include "models/candle_uno/candle_uno.h"
 #include "models/inception_v3/inception_v3.h"
 #include "models/split_test/split_test.h"
 #include "models/transformer/transformer.h"
@@ -63,6 +65,10 @@ tl::expected<ComputationGraph, std::string>
   } else if (model_name == "inception_v3") {
     return get_inception_v3_computation_graph(
         get_default_inception_v3_training_config());
+  } else if (model_name == "candle_uno") {
+    return get_candle_uno_computation_graph(get_default_candle_uno_config());
+  } else if (model_name == "bert") {
+    return get_bert_computation_graph(get_default_bert_config());
   } else if (model_name == "split_test") {
     int batch_size = 8;
     return get_split_test_computation_graph(batch_size);
@@ -135,8 +141,12 @@ int main(int argc, char **argv) {
           "output a dot representation of model's computation graph "
           "for preprocessed to help check series-parallel structure"});
 
-  std::vector<std::string> model_options = {
-      "transformer", "inception_v3", "split_test", "single_operator"};
+  std::vector<std::string> model_options = {"transformer",
+                                            "inception_v3",
+                                            "candle_uno",
+                                            "bert",
+                                            "split_test",
+                                            "single_operator"};
   CLIArgumentKey key_model_name = cli_add_positional_argument(
       cli,
       CLIPositionalArgumentSpec{
diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc
index c9d84a894..564cffaeb 100644
--- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc
+++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc
@@ -1,4 +1,6 @@
 #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
+#include "models/bert/bert.h"
+#include "models/candle_uno/candle_uno.h"
 #include "models/inception_v3/inception_v3.h"
 #include "models/split_test/split_test.h"
 #include "models/transformer/transformer.h"
@@ -302,6 +304,26 @@ TEST_SUITE(FF_TEST_SUITE) {
 
       CHECK(sp_decomposition.has_value());
     }
+
+    SUBCASE("candle_uno") {
+      ComputationGraph cg =
+          get_candle_uno_computation_graph(get_default_candle_uno_config());
+
+      std::optional<SeriesParallelDecomposition> sp_decomposition =
+          get_computation_graph_series_parallel_decomposition(cg);
+
+      CHECK(sp_decomposition.has_value());
+    }
+
+    SUBCASE("bert") {
+      ComputationGraph cg =
+          get_bert_computation_graph(get_default_bert_config());
+
+      std::optional<SeriesParallelDecomposition> sp_decomposition =
+          get_computation_graph_series_parallel_decomposition(cg);
+
+      CHECK(sp_decomposition.has_value());
+    }
   }
 }
 
@@ -347,5 +369,29 @@ TEST_SUITE(FF_TEST_SUITE) {
      std::string result =
          render_preprocessed_computation_graph_for_sp_decomposition(cg);
    }
+
+    SUBCASE("inception_v3") {
+      ComputationGraph cg = get_inception_v3_computation_graph(
+          get_default_inception_v3_training_config());
+
+      std::string result =
+          render_preprocessed_computation_graph_for_sp_decomposition(cg);
+    }
+
+    SUBCASE("candle_uno") {
+      ComputationGraph cg =
+          get_candle_uno_computation_graph(get_default_candle_uno_config());
+
+      std::string result =
+          render_preprocessed_computation_graph_for_sp_decomposition(cg);
+    }
+
+    SUBCASE("bert") {
+      ComputationGraph cg =
+          get_bert_computation_graph(get_default_bert_config());
+
+      std::string result =
+          render_preprocessed_computation_graph_for_sp_decomposition(cg);
+    }
   }
 }
diff --git a/lib/models/include/models/bert/bert.h b/lib/models/include/models/bert/bert.h
new file mode 100644
index 000000000..0047996b7
--- /dev/null
+++ b/lib/models/include/models/bert/bert.h
@@ -0,0 +1,47 @@
+#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H
+#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H
+
+#include "models/bert/bert_config.dtg.h"
+#include "pcg/computation_graph_builder.h"
+
+namespace FlexFlow {
+
+// Helper functions to construct the BERT model
+tensor_guid_t create_bert_feedforward_network(ComputationGraphBuilder &,
+                                              BertConfig const &,
+                                              tensor_guid_t const &,
+                                              InitializerAttrs const &,
+                                              InitializerAttrs const &);
+tensor_guid_t create_bert_encoder_layer(ComputationGraphBuilder &,
+                                        BertConfig const &,
+                                        tensor_guid_t const &,
+                                        InitializerAttrs const &,
+                                        InitializerAttrs const &);
+tensor_guid_t create_bert_encoder(ComputationGraphBuilder &,
+                                  BertConfig const &,
+                                  tensor_guid_t const &,
+                                  InitializerAttrs const &,
+                                  InitializerAttrs const &);
+
+/**
+ * @brief Get the base config of the BERT model.
+ *
+ * @details Refer to
+ * https://huggingface.co/docs/transformers/v4.18.0/en/model_doc/bert#transformers.BertConfig
+ * for default configs.
+ */
+BertConfig get_default_bert_config();
+
+/**
+ * @brief Get the BERT computation graph.
+ *
+ * @note This is a plain encoder-only model for pre-training.
+ *
+ * @param BertConfig The config of the BERT model.
+ * @return ComputationGraph The computation graph of a BERT model.
+ */
+ComputationGraph get_bert_computation_graph(BertConfig const &);
+
+} // namespace FlexFlow
+
+#endif
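The header above, together with `get_layers` from `pcg/computation_graph.h`, is all the new tests rely on. A minimal usage sketch (a hypothetical driver, not part of the patch; the 245-layer count is the value asserted by the doctest added later in this patch):

```cpp
#include "models/bert/bert.h"
#include "pcg/computation_graph.h"

using namespace ::FlexFlow;

int main() {
  // HuggingFace-style defaults: 12 encoder layers, hidden size 768, etc.
  BertConfig config = get_default_bert_config();

  // Build the encoder-only computation graph.
  ComputationGraph cg = get_bert_computation_graph(config);

  // The default config lowers to 245 layers (asserted by the new doctest).
  return get_layers(cg).size() == 245 ? 0 : 1;
}
```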
diff --git a/lib/models/include/models/bert/bert_config.struct.toml b/lib/models/include/models/bert/bert_config.struct.toml
new file mode 100644
index 000000000..398210cf4
--- /dev/null
+++ b/lib/models/include/models/bert/bert_config.struct.toml
@@ -0,0 +1,71 @@
+namespace = "FlexFlow"
+name = "BertConfig"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/activation.dtg.h",
+]
+
+[[fields]]
+name = "vocab_size"
+type = "size_t"
+
+[[fields]]
+name = "hidden_size"
+type = "size_t"
+
+[[fields]]
+name = "num_encoder_layers"
+type = "size_t"
+
+[[fields]]
+name = "num_heads"
+type = "size_t"
+
+[[fields]]
+name = "dim_feedforward"
+type = "size_t"
+
+[[fields]]
+name = "hidden_act"
+type = "::FlexFlow::Activation"
+
+[[fields]]
+name = "hidden_dropout_prob"
+type = "float"
+
+[[fields]]
+name = "attention_probs_dropout_prob"
+type = "float"
+
+[[fields]]
+name = "initializer_range"
+type = "float"
+
+[[fields]]
+name = "layer_norm_eps"
+type = "float"
+
+[[fields]]
+name = "position_embedding_type"
+type = "std::string"
+
+[[fields]]
+name = "classifier_dropout"
+type = "float"
+
+[[fields]]
+name = "sequence_length"
+type = "size_t"
+
+[[fields]]
+name = "batch_size"
+type = "size_t"
diff --git a/lib/models/include/models/candle_uno/candle_uno.h b/lib/models/include/models/candle_uno/candle_uno.h
new file mode 100644
index 000000000..a2d21f283
--- /dev/null
+++ b/lib/models/include/models/candle_uno/candle_uno.h
@@ -0,0 +1,42 @@
+#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H
+#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H
+
+#include "candle_uno_config.dtg.h"
+#include "pcg/computation_graph_builder.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace FlexFlow {
+
+// Helper functions to construct the Candle Uno model
+tensor_guid_t create_candle_uno_feature_model(ComputationGraphBuilder &,
+                                              CandleUnoConfig const &,
+                                              tensor_guid_t const &,
+                                              InitializerAttrs const &);
+
+/**
+ * @brief Get the default configs of the Candle Uno model.
+ *
+ * @details The default configs come from the dataset used by the original
+ * model:
+ * https://github.com/ECP-CANDLE/Benchmarks/tree/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno
+ */
+CandleUnoConfig get_default_candle_uno_config();
+
+/**
+ * @brief Get the Candle Uno computation graph.
+ *
+ * @details CandleUnoConfig.feature_shapes is a map from feature name to the
+ * number of channels for the feature, and CandleUnoConfig.input_features is a
+ * map from specific data identifier in the dataset to the feature name used in
+ * this model.
+ *
+ * @param CandleUnoConfig The config of the Candle Uno model.
+ * @return ComputationGraph The computation graph of a Candle Uno model.
+ */
+ComputationGraph get_candle_uno_computation_graph(CandleUnoConfig const &);
+
+} // namespace FlexFlow
+
+#endif
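As the doc comment above describes, `input_features` routes each dataset column to a named feature and `feature_shapes` gives that feature's channel count; columns that share a feature name share one encoder. A hedged sketch of that wiring (hypothetical driver code, using only the defaults defined later in this patch):

```cpp
#include "models/candle_uno/candle_uno.h"
#include "pcg/computation_graph.h"

using namespace ::FlexFlow;

int main() {
  CandleUnoConfig config = get_default_candle_uno_config();

  // "dose1" and "dose2" both map to the single-channel "dose" feature.
  size_t dose_channels =
      config.feature_shapes.at(config.input_features.at("dose1"));

  // Both drug columns reuse the 5270-channel "drug.descriptors" feature.
  size_t drug_channels =
      config.feature_shapes.at(config.input_features.at("drug1.descriptors"));

  ComputationGraph cg = get_candle_uno_computation_graph(config);
  return (dose_channels == 1 && drug_channels == 5270) ? 0 : 1;
}
```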
diff --git a/lib/models/include/models/candle_uno/candle_uno_config.struct.toml b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml
new file mode 100644
index 000000000..667a6531c
--- /dev/null
+++ b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml
@@ -0,0 +1,52 @@
+namespace = "FlexFlow"
+name = "CandleUnoConfig"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "<map>",
+  "<string>",
+  "<vector>",
+]
+
+src_includes = [
+  "utils/fmt/vector.h",
+  "utils/fmt/map.h",
+  "utils/hash/vector.h",
+  "utils/hash/map.h",
+]
+
+[[fields]]
+name = "batch_size"
+type = "size_t"
+
+[[fields]]
+name = "dense_layers"
+type = "std::vector<int>"
+
+[[fields]]
+name = "dense_feature_layers"
+type = "std::vector<int>"
+
+[[fields]]
+name = "feature_shapes"
+type = "std::map<std::string, size_t>"
+
+[[fields]]
+name = "input_features"
+type = "std::map<std::string, std::string>"
+
+[[fields]]
+name = "dropout"
+type = "float"
+
+[[fields]]
+name = "residual"
+type = "bool"
diff --git a/lib/models/src/models/bert/bert.cc b/lib/models/src/models/bert/bert.cc
new file mode 100644
index 000000000..cf48f2399
--- /dev/null
+++ b/lib/models/src/models/bert/bert.cc
@@ -0,0 +1,160 @@
+#include "models/bert/bert.h"
+#include "op-attrs/tensor_shape.h"
+#include "pcg/computation_graph.h"
+#include "pcg/initializers/truncated_normal_initializer_attrs.dtg.h"
+
+namespace FlexFlow {
+
+BertConfig get_default_bert_config() {
+  return BertConfig{/*vocab_size=*/30522,
+                    /*hidden_size=*/768,
+                    /*num_encoder_layers=*/12,
+                    /*num_heads=*/12,
+                    /*dim_feedforward=*/3072,
+                    /*hidden_act=*/Activation::GELU,
+                    /*hidden_dropout_prob=*/0.1,
+                    /*attention_probs_dropout_prob=*/0.1,
+                    /*initializer_range=*/0.02,
+                    /*layer_norm_eps=*/1e-12,
+                    /*position_embedding_type=*/"absolute",
+                    /*classifier_dropout=*/0.1,
+                    /*sequence_length=*/512,
+                    /*batch_size=*/64};
+}
+
+tensor_guid_t create_bert_feedforward_network(
+    ComputationGraphBuilder &cgb,
+    BertConfig const &config,
+    tensor_guid_t const &input,
+    InitializerAttrs const &bias_initializer,
+    InitializerAttrs const &projection_initializer) {
+  tensor_guid_t layer1_out =
+      cgb.dense(input,
+                config.dim_feedforward,
+                /*activation=*/config.hidden_act,
+                /*use_bias=*/true,
+                /*data_type=*/DataType::FLOAT,
+                /*projection_initializer=*/projection_initializer,
+                /*bias_initializer=*/bias_initializer);
+  tensor_guid_t dropout_out =
+      cgb.dropout(layer1_out, config.hidden_dropout_prob);
+  tensor_guid_t layer2_out =
+      cgb.dense(dropout_out,
+                config.hidden_size,
+                /*activation=*/std::nullopt,
+                /*use_bias=*/true,
+                /*data_type=*/DataType::FLOAT,
+                /*projection_initializer=*/projection_initializer,
+                /*bias_initializer=*/bias_initializer);
+  return cgb.dropout(layer2_out, config.hidden_dropout_prob);
+}
+
+tensor_guid_t create_bert_encoder_layer(
+    ComputationGraphBuilder &cgb,
+    BertConfig const &config,
+    tensor_guid_t const &input,
+    InitializerAttrs const &bias_initializer,
+    InitializerAttrs const &projection_initializer) {
+  assert(num_dims(cgb.get_shape(input)) == 3);
+  std::vector<int> layer_norm_axis = {2}; // Apply layernorm across the last dim
+  int kdim = config.dim_feedforward / config.num_heads;
+  int vdim = config.dim_feedforward / config.num_heads;
+  tensor_guid_t self_attention =
+      cgb.multihead_attention(input,
+                              input,
+                              input,
+                              config.hidden_size,
+                              config.num_heads,
+                              kdim,
+                              vdim,
+                              /*dropout=*/config.attention_probs_dropout_prob,
+                              /*bias=*/true,
+                              /*add_bias_kv=*/false,
+                              /*add_zero_attn=*/false,
+                              /*initializer=*/projection_initializer);
+  assert(are_tensor_guid_shapes_equivalent(
+      cgb.computation_graph, input, self_attention));
+
+  tensor_guid_t normalized = cgb.layer_norm(cgb.add(self_attention, input),
+                                            layer_norm_axis,
+                                            /*elementwise_affine=*/true,
+                                            config.layer_norm_eps);
+  assert(are_tensor_guid_shapes_equivalent(
+      cgb.computation_graph, input, normalized));
+
+  tensor_guid_t feedforward_output = create_bert_feedforward_network(
+      cgb, config, normalized, bias_initializer, projection_initializer);
+  assert(are_tensor_guid_shapes_equivalent(
+      cgb.computation_graph, input, feedforward_output));
+  return cgb.layer_norm(cgb.add(normalized, feedforward_output),
+                        layer_norm_axis,
+                        /*elementwise_affine=*/true,
+                        config.layer_norm_eps);
+}
+
+tensor_guid_t create_bert_encoder(
+    ComputationGraphBuilder &cgb,
+    BertConfig const &config,
+    tensor_guid_t const &input,
+    InitializerAttrs const &bias_initializer,
+    InitializerAttrs const &projection_initializer) {
+  tensor_guid_t t = input;
+  for (int i = 0; i < config.num_encoder_layers; i++) {
+    t = create_bert_encoder_layer(
+        cgb, config, t, bias_initializer, projection_initializer);
+  }
+  return t;
+}
+
+ComputationGraph get_bert_computation_graph(BertConfig const &config) {
+  if (config.position_embedding_type != "absolute") {
+    throw mk_runtime_error(
+        fmt::format("Currently only position_embedding_type=absolute is "
+                    "supported, but found position_embedding_type={}. "
+                    "If you need support for additional "
+                    "position_embedding_type values, please create an issue.",
+                    config.position_embedding_type));
+  }
+
+  ComputationGraphBuilder cgb;
+  InitializerAttrs projection_initializer =
+      InitializerAttrs{TruncatedNormalInitializerAttrs{
+          /*seed=*/0,
+          /*mean=*/0,
+          /*stddev=*/config.initializer_range,
+          /*min_cutoff=*/-2 * config.initializer_range,
+          /*max_cutoff=*/2 * config.initializer_range}};
+  InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}};
+
+  TensorShape input_shape = TensorShape{
+      TensorDims{FFOrdered<size_t>{
+          config.batch_size, config.sequence_length, config.hidden_size}},
+      DataType::FLOAT,
+  };
+  tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES);
+
+  tensor_guid_t encoder_output = create_bert_encoder(
+      cgb, config, input, bias_initializer, projection_initializer);
+  assert(are_tensor_guid_shapes_equivalent(
+      cgb.computation_graph, input, encoder_output));
+
+  tensor_guid_t out_prob =
+      cgb.softmax(cgb.dense(encoder_output,
+                            /*outDim=*/config.vocab_size,
+                            /*activation=*/config.hidden_act,
+                            /*use_bias=*/true,
+                            /*data_type=*/DataType::FLOAT,
+                            /*projection_initializer=*/projection_initializer,
+                            /*bias_initializer=*/bias_initializer));
+  assert(
+      (cgb.get_shape(out_prob) ==
+       TensorShape{
+           TensorDims{FFOrdered<size_t>{
+               config.batch_size, config.sequence_length, config.vocab_size}},
+           DataType::FLOAT,
+       }));
+
+  return cgb.computation_graph;
+}
+
+} // namespace FlexFlow
diff --git a/lib/models/src/models/candle_uno/candle_uno.cc b/lib/models/src/models/candle_uno/candle_uno.cc
new file mode 100644
index 000000000..4d52d515f
--- /dev/null
+++ b/lib/models/src/models/candle_uno/candle_uno.cc
@@ -0,0 +1,123 @@
+#include "models/candle_uno/candle_uno.h"
+#include "pcg/initializers/glorot_normal_attrs.dtg.h"
+
+namespace FlexFlow {
+
+CandleUnoConfig get_default_candle_uno_config() {
+  CandleUnoConfig config{
+      /*batch_size=*/64,
+      /*dense_layers=*/std::vector<int>(4, 4192),
+      /*dense_feature_layers=*/std::vector<int>(8, 4192),
+      /*feature_shapes=*/std::map<std::string, size_t>{},
+      /*input_features=*/std::map<std::string, std::string>{},
+      /*dropout=*/0.1,
+      /*residual=*/false};
+
+  config.feature_shapes["dose"] = 1;
+  config.feature_shapes["cell.rnaseq"] = 942;
+  config.feature_shapes["drug.descriptors"] = 5270;
+  config.feature_shapes["drug.fingerprints"] = 2048;
+
+  config.input_features["dose1"] = "dose";
+  config.input_features["dose2"] = "dose";
+  config.input_features["cell.rnaseq"] = "cell.rnaseq";
+  config.input_features["drug1.descriptors"] = "drug.descriptors";
+  config.input_features["drug1.fingerprints"] = "drug.fingerprints";
+  config.input_features["drug2.descriptors"] = "drug.descriptors";
+  config.input_features["drug2.fingerprints"] = "drug.fingerprints";
+
+  return config;
+}
+
+tensor_guid_t create_candle_uno_feature_model(
+    ComputationGraphBuilder &cgb,
+    CandleUnoConfig const &config,
+    tensor_guid_t const &input,
+    InitializerAttrs const &kernel_initializer) {
+  tensor_guid_t t = input;
+  for (int const dense_dim : config.dense_feature_layers) {
+    t = cgb.dense(t,
+                  dense_dim,
+                  Activation::RELU,
+                  /*use_bias=*/false,
+                  /*data_type=*/DataType::FLOAT,
+                  /*kernel_initializer=*/kernel_initializer);
+    if (config.dropout > 0) {
+      t = cgb.dropout(t, config.dropout);
+    }
+  }
+  return t;
+}
+
+ComputationGraph
+    get_candle_uno_computation_graph(CandleUnoConfig const &config) {
+  ComputationGraphBuilder cgb;
+  InitializerAttrs kernel_initializer =
+      InitializerAttrs{GlorotNormalAttrs{/*seed=*/0}};
+
+  auto create_input_tensor =
+      [&](FFOrdered<size_t> const &dims) -> tensor_guid_t {
+    TensorShape input_shape = TensorShape{
+        TensorDims{dims},
+        DataType::FLOAT,
+    };
+    return cgb.create_input(input_shape, CreateGrad::YES);
+  };
+
+  std::set<std::string> input_models;
+  for (auto const &shape : config.feature_shapes) {
+    auto const &type = shape.first;
+    if (type.find(".") != std::string::npos) {
+      std::string base_type = type.substr(0, type.find("."));
+      // The string parsing here is to match the original implementation at
+      // https://github.com/ECP-CANDLE/Benchmarks/blob/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno/uno_baseline_keras2.py#L178.
+ if (base_type == "cell" || base_type == "drug") { + input_models.insert(type); + } + } + } + + std::vector all_inputs; + std::vector encoded_inputs; + + for (auto const &input_feature : config.input_features) { + std::string const &feature_name = input_feature.second; + size_t shape = config.feature_shapes.at(feature_name); + tensor_guid_t input = create_input_tensor({config.batch_size, shape}); + all_inputs.push_back(input); + + if (contains(input_models, feature_name)) { + encoded_inputs.emplace_back(create_candle_uno_feature_model( + cgb, config, input, kernel_initializer)); + } else { + encoded_inputs.emplace_back(input); + } + } + + tensor_guid_t output = cgb.concat(encoded_inputs, /*axis=*/1); + for (int const &dense_layer_dim : config.dense_layers) { + tensor_guid_t residual_input = output; + output = cgb.dense(output, + dense_layer_dim, + Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + if (config.dropout > 0) { + output = cgb.dropout(output, config.dropout); + } + if (config.residual) { + output = cgb.add(output, residual_input); + } + } + output = cgb.dense(output, + /*outDim=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/bert/bert.cc b/lib/models/test/src/models/bert/bert.cc new file mode 100644 index 000000000..1defc3a1a --- /dev/null +++ b/lib/models/test/src/models/bert/bert.cc @@ -0,0 +1,33 @@ +#include "models/bert/bert.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_bert_computation_graph") { + + SUBCASE("default config") { + BertConfig config = get_default_bert_config(); + + ComputationGraph result = get_bert_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 245; + CHECK(result_num_layers == correct_num_layers); + } + } + + SUBCASE("throws on position_embedding_type != absolute as other values are " + "currently unsupported") { + BertConfig config = [] { + BertConfig c = get_default_bert_config(); + c.position_embedding_type = "relative_key"; + return c; + }(); + + CHECK_THROWS(get_bert_computation_graph(config)); + } + } +} diff --git a/lib/models/test/src/models/candle_uno/candle_uno.cc b/lib/models/test/src/models/candle_uno/candle_uno.cc new file mode 100644 index 000000000..e32c5b548 --- /dev/null +++ b/lib/models/test/src/models/candle_uno/candle_uno.cc @@ -0,0 +1,19 @@ +#include "models/candle_uno/candle_uno.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_candle_uno_computation_graph") { + CandleUnoConfig config = get_default_candle_uno_config(); + + ComputationGraph result = get_candle_uno_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 142; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 7a89b4bd7..1b8361abf 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -17,6 +17,16 @@ size_t num_shard_dims(ParallelTensorDims const &); ParallelTensorDimDegrees 
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
index 7a89b4bd7..1b8361abf 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
@@ -17,6 +17,16 @@ size_t num_shard_dims(ParallelTensorDims const &);
 ParallelTensorDimDegrees
     get_parallel_degrees(ParallelTensorDims const &);
 
+ParallelTensorDims lift_to_parallel(TensorDims const &);
+ParallelTensorDims
+    lift_to_parallel_with_degrees(TensorDims const &,
+                                  SumDegree const &,
+                                  DiscardCopyDegree const &,
+                                  FFOrdered<int> const &shard_degrees);
+ParallelTensorDims
+    lift_to_parallel_with_degrees(TensorDims const &,
+                                  ParallelTensorDimDegrees const &);
+
 int total_replica_degree(ParallelTensorDims const &);
 int total_shard_degree(ParallelTensorDims const &);
 int total_parallel_degree(ParallelTensorDims const &);
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
index 806a5f0de..a03151160 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
@@ -23,8 +23,8 @@ ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &);
 ParallelTensorShape lift_to_parallel(TensorShape const &);
 ParallelTensorShape
     lift_to_parallel_with_degrees(TensorShape const &,
-                                  SumDegree sum_degree,
-                                  DiscardCopyDegree discard_copy_degree,
+                                  SumDegree const &,
+                                  DiscardCopyDegree const &,
                                   FFOrdered<int> const &shard_degrees);
 ParallelTensorShape
     lift_to_parallel_with_degrees(TensorShape const &,
diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h
index c8af3b02e..ee44a3917 100644
--- a/lib/op-attrs/include/op-attrs/tensor_dims.h
+++ b/lib/op-attrs/include/op-attrs/tensor_dims.h
@@ -17,13 +17,6 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr,
 std::optional<TensorDims>
     get_broadcast_target_dims(std::unordered_set<TensorDims> const &);
 
-ParallelTensorDims lift_to_parallel(TensorDims const &);
-ParallelTensorDims
-    lift_to_parallel_with_degrees(TensorDims const &,
-                                  SumDegree sum_degree,
-                                  DiscardCopyDegree discard_copy_degree,
-                                  FFOrdered<int> const &shard_degrees);
-
 } // namespace FlexFlow
 
 #endif
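The implementations that follow pair each unparallelized dim with a shard degree and attach the two replica degrees. A small sketch of the intended semantics (hypothetical values; it assumes `total_parallel_degree` is the product of the replica and shard degrees, consistent with the accessors declared above):

```cpp
#include "op-attrs/parallel_tensor_dims.h"

using namespace ::FlexFlow;

int main() {
  TensorDims dims = TensorDims{FFOrdered<size_t>{12, 16}};

  // Trivial lift: shard degree 1 on every dim, SumDegree{1},
  // DiscardCopyDegree{1}.
  ParallelTensorDims trivial = lift_to_parallel(dims);

  // Explicit degrees: shard dim 0 two ways and mark the tensor sum-replicated.
  ParallelTensorDims lifted = lift_to_parallel_with_degrees(
      dims, SumDegree{2}, DiscardCopyDegree{1}, FFOrdered<int>{2, 1});

  // 2 (sum) * 1 (discard-copy) * 2 (shard dim 0) * 1 (shard dim 1) == 4.
  return total_parallel_degree(lifted) == 4 ? 0 : 1;
}
```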
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
index 61062b84b..295554556 100644
--- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
@@ -1,8 +1,10 @@
 #include "op-attrs/parallel_tensor_dims.h"
 #include "op-attrs/dim_ordered/transform.h"
+#include "op-attrs/dim_ordered/zip.h"
 #include "op-attrs/replica_parallel_dim.h"
 #include "op-attrs/replica_parallel_dim_set.h"
 #include "op-attrs/shard_parallel_dim.h"
+#include "op-attrs/tensor_dims.h"
 #include "utils/containers/all_of.h"
 #include "utils/containers/product.h"
 #include "utils/containers/transform.h"
@@ -37,6 +39,42 @@ ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) {
   };
 }
 
+ParallelTensorDims lift_to_parallel(TensorDims const &dims) {
+  std::vector<int> shard_degrees(num_dims(dims),
+                                 1); // 1 repeated num_dims(dims) times
+  return lift_to_parallel_with_degrees(
+      dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees);
+}
+
+ParallelTensorDims
+    lift_to_parallel_with_degrees(TensorDims const &unpar,
+                                  SumDegree const &sum_degree,
+                                  DiscardCopyDegree const &discard_copy_degree,
+                                  FFOrdered<int> const &shard_degrees) {
+  std::vector<ShardParallelDim> lifted =
+      transform(zip(vector_of(unpar.ff_ordered), vector_of(shard_degrees)),
+                [](std::pair<size_t, int> const &p) {
+                  size_t size = p.first;
+                  int degree = p.second;
+                  return ShardParallelDim{size, degree};
+                });
+
+  return ParallelTensorDims{FFOrdered<ShardParallelDim>{lifted},
+                            ReplicaParallelDimSet{
+                                sum_degree,
+                                discard_copy_degree,
+                            }};
+}
+
+ParallelTensorDims
+    lift_to_parallel_with_degrees(TensorDims const &unpar,
+                                  ParallelTensorDimDegrees const &degrees) {
+  return lift_to_parallel_with_degrees(unpar,
+                                       degrees.sum_degree,
+                                       degrees.discard_copy_degree,
+                                       degrees.shard_degrees);
+}
+
 int total_replica_degree(ParallelTensorDims const &dims) {
   return dims.replica_dims.discard_copy_degree.value *
          dims.replica_dims.sum_degree.value;
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
index 3cd0f47a5..0663795db 100644
--- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
@@ -68,24 +68,24 @@ ParallelTensorShape lift_to_parallel(TensorShape const &s) {
 }
 
 ParallelTensorShape
-    lift_to_parallel_with_degrees(TensorShape const &s,
-                                  SumDegree sum_degree,
-                                  DiscardCopyDegree discard_copy_degree,
+    lift_to_parallel_with_degrees(TensorShape const &unpar,
+                                  SumDegree const &sum_degree,
+                                  DiscardCopyDegree const &discard_copy_degree,
                                   FFOrdered<int> const &shard_degrees) {
   return ParallelTensorShape{
       lift_to_parallel_with_degrees(
-          s.dims, sum_degree, discard_copy_degree, shard_degrees),
-      s.data_type,
+          unpar.dims, sum_degree, discard_copy_degree, shard_degrees),
+      unpar.data_type,
   };
 }
 
 ParallelTensorShape
-    lift_to_parallel_with_degrees(TensorShape const &s,
+    lift_to_parallel_with_degrees(TensorShape const &unpar,
                                   ParallelTensorDimDegrees const &degrees) {
-  return lift_to_parallel_with_degrees(s,
-                                       degrees.sum_degree,
-                                       degrees.discard_copy_degree,
-                                       degrees.shard_degrees);
+  return ParallelTensorShape{
+      lift_to_parallel_with_degrees(unpar.dims, degrees),
+      unpar.data_type,
+  };
 }
 
 TensorShape require_not_parallel(ParallelTensorShape const &s) {
diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc
index ba7d6e835..1bb050db5 100644
--- a/lib/op-attrs/src/op-attrs/tensor_dims.cc
+++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc
@@ -59,31 +59,4 @@ std::optional<TensorDims>
   return std::nullopt;
 }
 
-ParallelTensorDims lift_to_parallel(TensorDims const &dims) {
-  std::vector<int> shard_degrees(num_dims(dims),
-                                 1); // 1 repeated num_dims(dims) times
-  return lift_to_parallel_with_degrees(
-      dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees);
-}
-
-ParallelTensorDims
-    lift_to_parallel_with_degrees(TensorDims const &dims,
-                                  SumDegree sum_degree,
-                                  DiscardCopyDegree discard_copy_degree,
-                                  FFOrdered<int> const &shard_degrees) {
-  std::vector<ShardParallelDim> lifted =
-      transform(zip(vector_of(dims.ff_ordered), vector_of(shard_degrees)),
-                [](std::pair<size_t, int> const &p) {
-                  size_t size = p.first;
-                  int degree = p.second;
-                  return ShardParallelDim(size, degree);
-                });
-
-  return ParallelTensorDims{FFOrdered<ShardParallelDim>{lifted},
-                            ReplicaParallelDimSet{
-                                sum_degree,
-                                discard_copy_degree,
-                            }};
-}
-
 } // namespace FlexFlow
diff --git a/lib/op-attrs/test/src/op-attrs/ops/concat.cc b/lib/op-attrs/test/src/op-attrs/ops/concat.cc
new file mode 100644
index 000000000..9e842c3eb
--- /dev/null
+++ b/lib/op-attrs/test/src/op-attrs/ops/concat.cc
@@ -0,0 +1,331 @@
+#include "op-attrs/ops/concat.h"
+#include "op-attrs/parallel_tensor_shape.h"
+#include "test/utils/doctest/fmt/expected.h"
+#include "test/utils/doctest/fmt/optional.h"
+#include "utils/expected.h"
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("get_output_shape(ConcatAttrs, std::vector<TensorShape>)") {
+    ConcatAttrs attrs = ConcatAttrs{
+        ff_dim_t{1},
+    };
+
+    SUBCASE("empty input shapes list passed") {
+      std::vector<TensorShape> input_shapes = {};
+
+      std::optional<TensorShape> result =
+          optional_from_expected(get_output_shape(attrs, input_shapes));
+      std::optional<TensorShape> correct = std::nullopt;
+
+      CHECK(result == correct);
+    }
+
+    size_t dim0_size = 12;
+    size_t dim2_size = 20;
+    TensorShape input_shape1 = TensorShape{
+        TensorDims{FFOrdered<size_t>{
+            dim0_size,
+            14,
+            dim2_size,
+        }},
+        DataType::FLOAT,
+    };
+
+    SUBCASE("single element input shapes list passed") {
+      std::vector<TensorShape> input_shapes = {input_shape1};
+
+      std::optional<TensorShape> result =
+          optional_from_expected(get_output_shape(attrs, input_shapes));
+      std::optional<TensorShape> correct = std::nullopt;
+
+      CHECK(result == correct);
+    }
+
+    TensorShape input_shape2 = TensorShape{
+        TensorDims{FFOrdered<size_t>{
+            dim0_size,
+            16,
+            dim2_size,
+        }},
+        DataType::FLOAT,
+    };
+
+    TensorShape input_shape3 = TensorShape{
+        TensorDims{FFOrdered<size_t>{dim0_size, 18, dim2_size}},
+        DataType::FLOAT,
+    };
+
+    SUBCASE("input shapes do not share the same num_dims") {
+      TensorShape mismatched_num_dims = TensorShape{
+          TensorDims{FFOrdered<size_t>{
+              dim0_size,
+              20,
+              dim2_size,
+              1,
+          }},
+          DataType::FLOAT,
+      };
+
+      std::vector<TensorShape> input_shapes = {
+          input_shape1, input_shape2, input_shape3, mismatched_num_dims};
+
+      std::optional<TensorShape> result =
+          optional_from_expected(get_output_shape(attrs, input_shapes));
+      std::optional<TensorShape> correct = std::nullopt;
+
+      CHECK(result == correct);
+    }
+
+    SUBCASE("concat axis is out of bounds") {
+      attrs = ConcatAttrs{
+          ff_dim_t{3},
+      };
+
+      std::vector<TensorShape> input_shapes = {
+          input_shape1, input_shape2, input_shape3};
+
+      std::optional<TensorShape> result =
+          optional_from_expected(get_output_shape(attrs, input_shapes));
+      std::optional<TensorShape> correct = std::nullopt;
+
+      CHECK(result == correct);
+    }
+
+    SUBCASE("input shapes are valid") {
+      std::vector<TensorShape> input_shapes = {
+          input_shape1, input_shape2, input_shape3};
+
+      tl::expected<TensorShape, std::string> result =
+          get_output_shape(attrs, input_shapes);
+      tl::expected<TensorShape, std::string> correct = TensorShape{
+          TensorDims{FFOrdered<size_t>{
+              dim0_size,
+              14 + 16 + 18,
+              dim2_size,
+          }},
+          DataType::FLOAT,
+      };
+
+      CHECK(result == correct);
+    }
+  }
+
+  TEST_CASE(
+      "get_output_shape(ConcatAttrs, std::vector<ParallelTensorShape>)") {
+    ConcatAttrs attrs = ConcatAttrs{
+        ff_dim_t{1},
+    };
+
+    size_t dim0_size = 12;
+    size_t dim2_size = 20;
+
+    TensorShape input_shape1 = TensorShape{
+        TensorDims{FFOrdered<size_t>{
+            dim0_size,
+            14,
+            dim2_size,
+        }},
+        DataType::FLOAT,
+    };
+
+    TensorShape input_shape2 = TensorShape{
+        TensorDims{FFOrdered<size_t>{
+            dim0_size,
+            16,
+            dim2_size,
+        }},
+        DataType::FLOAT,
+    };
+
+    TensorShape input_shape3 = TensorShape{
+        TensorDims{FFOrdered<size_t>{dim0_size, 18, dim2_size}},
+        DataType::FLOAT,
+    };
+
+    TensorShape output_shape = TensorShape{
+        TensorDims{FFOrdered<size_t>{dim0_size, 14 + 16 + 18, dim2_size}},
+        DataType::FLOAT,
+    };
+
+    auto lift_input1 =
+        [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) {
+          return lift_to_parallel_with_degrees(
+              input_shape1, o_sum, o_eq, FFOrdered<int>{o0, o1, o2});
+        };
+
+    auto lift_input2 =
+        [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) {
+          return lift_to_parallel_with_degrees(
+              input_shape2, o_sum, o_eq, FFOrdered<int>{o0, o1, o2});
+        };
+
+    auto lift_input3 =
+        [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) {
+          return lift_to_parallel_with_degrees(
+              input_shape3, o_sum, o_eq, FFOrdered<int>{o0, o1, o2});
+        };
+
+    auto lift_output =
+        [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) {
+          return lift_to_parallel_with_degrees(
+              output_shape, o_sum, o_eq, FFOrdered<int>{o0, o1, o2});
+        };
+
+    SUBCASE("sum reduction parallelism") {
+      SUBCASE("matching") {
+        SumDegree sum_degree = SumDegree{2};
+
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(sum_degree, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input2(sum_degree, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input3(sum_degree, DiscardCopyDegree{1}, 1, 1, 1),
+        };
+
+        tl::expected<ParallelTensorShape, std::string> result =
+            get_output_shape(attrs, inputs);
+        tl::expected<ParallelTensorShape, std::string> correct =
+            lift_output(sum_degree, DiscardCopyDegree{1}, 1, 1, 1);
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("not matching") {
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{2}, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input2(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input3(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1),
+        };
+
+        std::optional<ParallelTensorShape> result =
+            optional_from_expected(get_output_shape(attrs, inputs));
+        std::optional<ParallelTensorShape> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+
+    SUBCASE("discard copy reduction parallelism") {
+      SUBCASE("matching") {
+        DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2};
+
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{1}, discard_copy_degree, 1, 1, 1),
+            lift_input2(SumDegree{1}, discard_copy_degree, 1, 1, 1),
+            lift_input3(SumDegree{1}, discard_copy_degree, 1, 1, 1),
+        };
+
+        tl::expected<ParallelTensorShape, std::string> result =
+            get_output_shape(attrs, inputs);
+        tl::expected<ParallelTensorShape, std::string> correct =
+            lift_output(SumDegree{1}, discard_copy_degree, 1, 1, 1);
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("not matching") {
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1),
+            lift_input2(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1),
+            lift_input3(SumDegree{1}, DiscardCopyDegree{4}, 1, 1, 1),
+        };
+
+        std::optional<ParallelTensorShape> result =
+            optional_from_expected(get_output_shape(attrs, inputs));
+        std::optional<ParallelTensorShape> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+
+    SUBCASE("parallelism in axis dim") {
+      SUBCASE("matching") {
+        int degree = 2;
+
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1),
+            lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1),
+            lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1),
+        };
+
+        std::optional<ParallelTensorShape> result =
+            optional_from_expected(get_output_shape(attrs, inputs));
+        std::optional<ParallelTensorShape> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("not matching") {
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1),
+            lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, 2, 1),
+        };
+
+        std::optional<ParallelTensorShape> result =
+            optional_from_expected(get_output_shape(attrs, inputs));
+        std::optional<ParallelTensorShape> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+
+    SUBCASE("parallelism in non-axis shard dims") {
+      SUBCASE("matching") {
+        int degree0 = 2;
+        int degree2 = 4;
+
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(
+                SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2),
+            lift_input2(
+                SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2),
+            lift_input3(
+                SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2),
+        };
+
+        tl::expected<ParallelTensorShape, std::string> result =
+            get_output_shape(attrs, inputs);
+        tl::expected<ParallelTensorShape, std::string> correct = lift_output(
+            SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2);
+
+        CHECK(result == correct);
+      }
+
+      SUBCASE("not matching") {
+        std::vector<ParallelTensorShape> inputs = {
+            lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 2, 1, 4),
+            lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2),
+            lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2),
+        };
+
+        std::optional<ParallelTensorShape> result =
+            optional_from_expected(get_output_shape(attrs, inputs));
+        std::optional<ParallelTensorShape> correct = std::nullopt;
+
+        CHECK(result == correct);
+      }
+    }
+
+    SUBCASE("parallelism degrees are not mutually exclusive") {
+      SumDegree sum_degree = SumDegree{3};
+      DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{5};
+      int degree0 = 2;
+      int degree2 = 4;
+
+      std::vector<ParallelTensorShape> inputs = {
+          lift_input1(sum_degree, discard_copy_degree, degree0, 1, degree2),
+          lift_input2(sum_degree, discard_copy_degree, degree0, 1, degree2),
+          lift_input3(sum_degree, discard_copy_degree, degree0, 1, degree2),
+      };
+
+      tl::expected<ParallelTensorShape, std::string> result =
+          get_output_shape(attrs, inputs);
+      tl::expected<ParallelTensorShape, std::string> correct =
+          lift_output(sum_degree, discard_copy_degree, degree0, 1, degree2);
+
+      CHECK(result == correct);
+    }
+  }
+}
diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml
index 1ea9ce05a..2e878c5c5 100644
--- a/lib/pcg/include/pcg/initializer_attrs.variant.toml
+++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml
@@ -11,9 +11,11 @@ features = [
 
 includes = [
   "pcg/initializers/glorot_uniform_attrs.dtg.h",
+  "pcg/initializers/glorot_normal_attrs.dtg.h",
   "pcg/initializers/zero_initializer_attrs.dtg.h",
   "pcg/initializers/uniform_initializer_attrs.h",
   "pcg/initializers/norm_initializer_attrs.dtg.h",
+  "pcg/initializers/truncated_normal_initializer_attrs.dtg.h",
   "pcg/initializers/constant_initializer_attrs.dtg.h",
 ]
 
@@ -21,6 +23,10 @@ includes = [
 type = "::FlexFlow::GlorotUniformAttrs"
 key = "glorot_uniform"
 
+[[values]]
+type = "::FlexFlow::GlorotNormalAttrs"
+key = "glorot_normal"
+
 [[values]]
 type = "::FlexFlow::ZeroInitializerAttrs"
 key = "zero"
@@ -33,6 +39,10 @@ key = "uniform"
 type = "::FlexFlow::NormInitializerAttrs"
 key = "normal"
 
+[[values]]
+type = "::FlexFlow::TruncatedNormalInitializerAttrs"
+key = "truncated_normal"
+
 [[values]]
 type = "::FlexFlow::ConstantInitializerAttrs"
 key = "constant"
diff --git a/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml
new file mode 100644
index 000000000..fd0d8eb9b
--- /dev/null
+++ b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "GlorotNormalAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "seed"
+type = "int"
diff --git a/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml
new file mode 100644
index 000000000..9e4ec0272
--- /dev/null
+++ b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml
@@ -0,0 +1,30 @@
+namespace = "FlexFlow"
+name = "TruncatedNormalInitializerAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "seed"
+type = "int"
+
+[[fields]]
+name = "mean"
+type = "float"
+
+[[fields]]
+name = "stddev"
+type = "float"
+
+[[fields]]
+name = "min_cutoff"
+type = "float"
+
+[[fields]]
+name = "max_cutoff"
+type = "float"
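Each `[[values]]` entry in the variant needs a unique key, which is why the truncated-normal entry uses `truncated_normal` rather than colliding with `NormInitializerAttrs`'s `normal`. Both new initializers are constructed through the `InitializerAttrs` variant, as the model code earlier in this patch does; a minimal sketch, assuming the dtg-generated headers:

```cpp
#include "pcg/initializer_attrs.dtg.h"

using namespace ::FlexFlow;

int main() {
  // Glorot (Xavier) normal, used for Candle Uno's dense kernels.
  InitializerAttrs glorot = InitializerAttrs{GlorotNormalAttrs{/*seed=*/0}};

  // Truncated normal clipped at +/-2 stddev, used for BERT's projections
  // (stddev = initializer_range = 0.02 in the default BertConfig).
  float range = 0.02f;
  InitializerAttrs truncated = InitializerAttrs{
      TruncatedNormalInitializerAttrs{/*seed=*/0,
                                      /*mean=*/0.0f,
                                      /*stddev=*/range,
                                      /*min_cutoff=*/-2 * range,
                                      /*max_cutoff=*/2 * range}};
  return 0;
}
```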
diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h
index 20ab6ce44..cb652a9e6 100644
--- a/lib/utils/include/utils/containers.decl.h
+++ b/lib/utils/include/utils/containers.decl.h
@@ -11,9 +11,6 @@
 
 namespace FlexFlow {
 
-template <typename Container, typename Element = typename Container::value_type>
-Element sum(Container const &container);
-
 template <typename Container,
           typename ConditionF,
           typename Element = typename Container::value_type>
@@ -68,9 +65,6 @@ std::optional<typename C::value_type> maybe_get_only(C const &c);
 template <typename Container, typename Function>
 std::optional<bool> optional_all_of(Container const &, Function const &);
 
-template <typename C>
-bool are_all_same(C const &c);
-
 template <typename T, typename F>
 std::function<bool(T const &, T const &)> compare_by(F const &f);
 
diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h
index f60ef77cd..6ac9eb10b 100644
--- a/lib/utils/include/utils/containers.h
+++ b/lib/utils/include/utils/containers.h
@@ -31,15 +31,6 @@
 
 namespace FlexFlow {
 
-template <typename Container, typename Element = typename Container::value_type>
-Element sum(Container const &container) {
-  Element result = 0;
-  for (Element const &element : container) {
-    result += element;
-  }
-  return result;
-}
-
 template <typename Container, typename ConditionF, typename Element = typename Container::value_type>
 Element sum_where(Container const &container, ConditionF const &condition) {
   Element result = 0;
@@ -135,17 +126,6 @@ std::optional<bool> optional_all_of(Container const &container,
   return true;
 }
 
-template <typename C>
-bool are_all_same(C const &c) {
-  auto const &first = *c.cbegin();
-  for (auto const &v : c) {
-    if (v != first) {
-      return false;
-    }
-  }
-  return true;
-}
-
 template <typename In, typename F, typename Out = typename std::invoke_result_t<F, In>::value_type>
 std::vector<Out> flatmap(std::vector<In> const &v, F const &f) {
   std::vector<Out> result;