Add BERT model computation graph (#1488)
* Add initial bert model structure

* Update following review

* Rename config

* Add additional bert configs

* Update based on review

* Added assert checks

* Add error message for unsupported BertConfig.position_embedding_type

* Format

* Fix typo

* Add bert to export-model-arch

* Format

---------

Co-authored-by: Colin Unger <lockshaw@lockshaw.net>
hsdfzhsdfz and lockshaw authored Sep 23, 2024
1 parent dbb642a commit cf96db6
Showing 8 changed files with 379 additions and 0 deletions.
4 changes: 4 additions & 0 deletions bin/export-model-arch/src/export_model_arch.cc
@@ -1,6 +1,7 @@
#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
@@ -66,6 +67,8 @@ tl::expected<ComputationGraph, std::string>
get_default_inception_v3_training_config());
} else if (model_name == "candle_uno") {
return get_candle_uno_computation_graph(get_default_candle_uno_config());
} else if (model_name == "bert") {
return get_bert_computation_graph(get_default_bert_config());
} else if (model_name == "split_test") {
int batch_size = 8;
return get_split_test_computation_graph(batch_size);
@@ -141,6 +144,7 @@ int main(int argc, char **argv) {
std::vector<std::string> model_options = {"transformer",
"inception_v3",
"candle_uno",
"bert",
"split_test",
"single_operator"};
CLIArgumentKey key_model_name = cli_add_positional_argument(
@@ -1,4 +1,5 @@
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
@@ -313,6 +314,16 @@ TEST_SUITE(FF_TEST_SUITE) {

CHECK(sp_decomposition.has_value());
}

SUBCASE("bert") {
ComputationGraph cg =
get_bert_computation_graph(get_default_bert_config());

std::optional<SeriesParallelDecomposition> sp_decomposition =
get_computation_graph_series_parallel_decomposition(cg);

CHECK(sp_decomposition.has_value());
}
}
}

@@ -358,5 +369,29 @@ TEST_SUITE(FF_TEST_SUITE) {
std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}

SUBCASE("inception_v3") {
ComputationGraph cg = get_inception_v3_computation_graph(
get_default_inception_v3_training_config());

std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}

SUBCASE("candle_uno") {
ComputationGraph cg =
get_candle_uno_computation_graph(get_default_candle_uno_config());

std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}

SUBCASE("bert") {
ComputationGraph cg =
get_bert_computation_graph(get_default_bert_config());

std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}
}
}
41 changes: 41 additions & 0 deletions lib/models/include/models/bert/bert.h
@@ -0,0 +1,41 @@
#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H
#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H

#include "models/bert/bert_config.dtg.h"
#include "pcg/computation_graph_builder.h"

namespace FlexFlow {

// Helper functions to construct the BERT model
tensor_guid_t create_bert_feedforward_network(ComputationGraphBuilder &,
BertConfig const &,
tensor_guid_t const &);
tensor_guid_t create_bert_encoder_layer(ComputationGraphBuilder &,
BertConfig const &,
tensor_guid_t const &);
tensor_guid_t create_bert_encoder(ComputationGraphBuilder &,
BertConfig const &,
tensor_guid_t const &);

/**
* @brief Get the base config of the BERT model.
*
* @details Refer to
* https://huggingface.co/docs/transformers/v4.18.0/en/model_doc/bert#transformers.BertConfig
* for default configs.
*/
BertConfig get_default_bert_config();

/**
* @brief Get the BERT computation graph.
*
* @note This is a plain encoder-only model for pre-training.
*
* @param BertConfig The config of the BERT model.
* @return ComputationGraph The computation graph of a BERT model.
*/
ComputationGraph get_bert_computation_graph(BertConfig const &);

} // namespace FlexFlow

#endif
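
The two entry points declared above are exercised by the export-model-arch and test changes elsewhere in this commit. A minimal usage sketch (assuming the same includes and FlexFlow build setup as the unit test added below) might look like:

    #include "models/bert/bert.h"
    #include "pcg/computation_graph.h"

    using namespace ::FlexFlow;

    int main() {
      // Build the BERT computation graph from the HuggingFace-style defaults.
      BertConfig config = get_default_bert_config();
      ComputationGraph cg = get_bert_computation_graph(config);

      // get_layers(...) is used the same way in the new unit test to count layers.
      return get_layers(cg).size() > 0 ? 0 : 1;
    }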
71 changes: 71 additions & 0 deletions lib/models/include/models/bert/bert_config.struct.toml
@@ -0,0 +1,71 @@
namespace = "FlexFlow"
name = "BertConfig"

features = [
"eq",
"ord",
"hash",
"json",
"rapidcheck",
"fmt",
]

includes = [
"op-attrs/activation.dtg.h",
]

[[fields]]
name = "vocab_size"
type = "size_t"

[[fields]]
name = "hidden_size"
type = "size_t"

[[fields]]
name = "num_encoder_layers"
type = "size_t"

[[fields]]
name = "num_heads"
type = "size_t"

[[fields]]
name = "dim_feedforward"
type = "size_t"

[[fields]]
name = "hidden_act"
type = "::FlexFlow::Activation"

[[fields]]
name = "hidden_dropout_prob"
type = "float"

[[fields]]
name = "attention_probs_dropout_prob"
type = "float"

[[fields]]
name = "initializer_range"
type = "float"

[[fields]]
name = "layer_norm_eps"
type = "float"

[[fields]]
name = "position_embedding_type"
type = "std::string"

[[fields]]
name = "classifier_dropout"
type = "float"

[[fields]]
name = "sequence_length"
type = "size_t"

[[fields]]
name = "batch_size"
type = "size_t"
160 changes: 160 additions & 0 deletions lib/models/src/models/bert/bert.cc
@@ -0,0 +1,160 @@
#include "models/bert/bert.h"
#include "op-attrs/tensor_shape.h"
#include "pcg/computation_graph.h"
#include "pcg/initializers/truncated_normal_initializer_attrs.dtg.h"

namespace FlexFlow {

BertConfig get_default_bert_config() {
return BertConfig{/*vocab_size=*/30522,
/*hidden_size=*/768,
/*num_encoder_layers=*/12,
/*num_heads=*/12,
/*dim_feedforward=*/3072,
/*hidden_act=*/Activation::GELU,
/*hidden_dropout_prob=*/0.1,
/*attention_probs_dropout_prob=*/0.1,
/*initializer_range=*/0.02,
/*layer_norm_eps=*/1e-12,
/*position_embedding_type=*/"absolute",
/*classifier_dropout=*/0.1,
/*sequence_length=*/512,
/*batch_size=*/64};
}

tensor_guid_t
create_feedforward_network(ComputationGraphBuilder &cgb,
BertConfig const &config,
tensor_guid_t const &input,
InitializerAttrs const &bias_initializer,
InitializerAttrs const &projection_initializer) {
tensor_guid_t layer1_out =
cgb.dense(input,
config.dim_feedforward,
/*activation=*/config.hidden_act,
/*use_bias=*/true,
/*data_type=*/DataType::FLOAT,
/*projection_initializer=*/projection_initializer,
/*bias_initializer=*/bias_initializer);
tensor_guid_t dropout_out =
cgb.dropout(layer1_out, config.hidden_dropout_prob);
tensor_guid_t layer2_out =
cgb.dense(dropout_out,
config.hidden_size,
/*activation=*/std::nullopt,
/*use_bias=*/true,
/*data_type=*/DataType::FLOAT,
/*projection_initializer=*/projection_initializer,
/*bias_initializer=*/bias_initializer);
return cgb.dropout(layer2_out, config.hidden_dropout_prob);
};

tensor_guid_t
create_bert_encoder_layer(ComputationGraphBuilder &cgb,
BertConfig const &config,
tensor_guid_t const &input,
InitializerAttrs const &bias_initializer,
InitializerAttrs const &projection_initializer) {
assert(num_dims(cgb.get_shape(input)) == 3);
std::vector<int> layer_norm_axis = {2}; // Apply layernorm across the last dim
int kdim = config.dim_feedforward / config.num_heads;
int vdim = config.dim_feedforward / config.num_heads;
tensor_guid_t self_attention =
cgb.multihead_attention(input,
input,
input,
config.hidden_size,
config.num_heads,
kdim,
vdim,
/*dropout=*/config.attention_probs_dropout_prob,
/*bias=*/true,
/*add_bias_kv=*/false,
/*add_zero_attn=*/false,
/*initializer=*/projection_initializer);
assert(are_tensor_guid_shapes_equivalent(
cgb.computation_graph, input, self_attention));

tensor_guid_t normalized = cgb.layer_norm(cgb.add(self_attention, input),
layer_norm_axis,
/*elementwise_affine=*/true,
config.layer_norm_eps);
assert(are_tensor_guid_shapes_equivalent(
cgb.computation_graph, input, normalized));

tensor_guid_t feedforward_output = create_feedforward_network(
cgb, config, normalized, bias_initializer, projection_initializer);
assert(are_tensor_guid_shapes_equivalent(
cgb.computation_graph, input, feedforward_output));
return cgb.layer_norm(cgb.add(normalized, feedforward_output),
layer_norm_axis,
/*elementwise_affine=*/true,
config.layer_norm_eps);
}

tensor_guid_t
create_bert_encoder(ComputationGraphBuilder &cgb,
BertConfig const &config,
tensor_guid_t const &input,
InitializerAttrs const &bias_initializer,
InitializerAttrs const &projection_initializer) {
tensor_guid_t t = input;
for (int i = 0; i < config.num_encoder_layers; i++) {
t = create_bert_encoder_layer(
cgb, config, t, bias_initializer, projection_initializer);
}
return t;
};

ComputationGraph get_bert_computation_graph(BertConfig const &config) {
if (config.position_embedding_type != "absolute") {
throw mk_runtime_error(
fmt::format("Currently only position_embedding_type=absolute is "
"supported, but found position_embedding_type={}. "
"If you need support for additional "
"position_embedding_type values, please create an issue.",
config.position_embedding_type));
}

ComputationGraphBuilder cgb;
InitializerAttrs projection_initializer =
InitializerAttrs{TruncatedNormalInitializerAttrs{
/*seed=*/0,
/*mean=*/0,
/*stddev=*/config.initializer_range,
/*min_cutoff=*/-2 * config.initializer_range,
/*max_cutoff=*/2 * config.initializer_range}};
InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}};

TensorShape input_shape = TensorShape{
TensorDims{FFOrdered<size_t>{
config.batch_size, config.sequence_length, config.hidden_size}},
DataType::FLOAT,
};
tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES);

tensor_guid_t encoder_output = create_bert_encoder(
cgb, config, input, bias_initializer, projection_initializer);
assert(are_tensor_guid_shapes_equivalent(
cgb.computation_graph, input, encoder_output));

tensor_guid_t out_prob =
cgb.softmax(cgb.dense(encoder_output,
/*outDim=*/config.vocab_size,
/*activation=*/config.hidden_act,
/*use_bias=*/true,
/*data_type=*/DataType::FLOAT,
/*projection_initializer=*/projection_initializer,
/*bias_initializer=*/bias_initializer));
assert(
(cgb.get_shape(out_prob) ==
TensorShape{
TensorDims{FFOrdered<size_t>{
config.batch_size, config.sequence_length, config.vocab_size}},
DataType::FLOAT,
}));

return cgb.computation_graph;
}

} // namespace FlexFlow
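
Since the generated BertConfig fields are plain assignable members (the unit test below overrides position_embedding_type the same way), a smaller variant can be sketched by tweaking the defaults before building the graph. The sizes here are illustrative only and are not values used anywhere in this commit:

    // Sketch: build a reduced BERT variant from the defaults.
    // position_embedding_type must stay "absolute"; anything else makes
    // get_bert_computation_graph throw, per the check above.
    BertConfig small_config = [] {
      BertConfig c = get_default_bert_config();
      c.num_encoder_layers = 2;
      c.hidden_size = 128;
      c.num_heads = 2;
      c.dim_feedforward = 512;
      return c;
    }();
    ComputationGraph small_cg = get_bert_computation_graph(small_config);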
33 changes: 33 additions & 0 deletions lib/models/test/src/models/bert/bert.cc
@@ -0,0 +1,33 @@
#include "models/bert/bert.h"
#include "pcg/computation_graph.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("get_bert_computation_graph") {

SUBCASE("default config") {
BertConfig config = get_default_bert_config();

ComputationGraph result = get_bert_computation_graph(config);

SUBCASE("num layers") {
int result_num_layers = get_layers(result).size();
int correct_num_layers = 245;
CHECK(result_num_layers == correct_num_layers);
}
}

SUBCASE("throws on position_embedding_type != absolute as other values are "
"currently unsupported") {
BertConfig config = [] {
BertConfig c = get_default_bert_config();
c.position_embedding_type = "relative_key";
return c;
}();

CHECK_THROWS(get_bert_computation_graph(config));
}
}
}