Skip to content

Commit

Permalink
mapping error
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda7xx committed Feb 9, 2024
1 parent a7ef500 commit a529b67
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,4 @@ class FFMapper : public NullMapper {
};

}; // namespace FlexFlow
#endif
#endif
2 changes: 1 addition & 1 deletion include/flexflow/ops/fused.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class FusedOp : public Op {
MachineView const &pc,
CostMetrics &cost_metrics) const override;

void capture_graph(Task const *task,
static void capture_graph(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime, cudaGraph_t& graph, cudaGraphExec_t& instance);
Expand Down
2 changes: 1 addition & 1 deletion src/mapper/mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1498,4 +1498,4 @@ void FFMapper::update_mappers(Machine machine,

FFMapper::~FFMapper(void) {}

}; // namespace FlexFlow
}; // namespace FlexFlow
13 changes: 9 additions & 4 deletions src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1157,14 +1157,19 @@ __host__ void
Context ctx,
Runtime *runtime) {
// create new cuda graph
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
cudaGraph_t graph;
cudaGraphExec_t instance;

FusedOpMeta *metas = *((FusedOpMeta **)task->local_args);
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
FusedOp const *fused = metas->fused_op;
std::tuple<int, int, bool> graph_params =
std::make_tuple(bc->num_active_requests(),
bc->num_active_tokens(),
bc->num_generation_tokens > 0);
int scenario = 0;
cudaEvent_t t_start_update, t_end_update;
int shard_id = task->index_point.point_data[0];
auto it = metas->graph_collections.find(graph_params);
if(it != metas->graph_collections.end()) {
Expand Down Expand Up @@ -1202,7 +1207,7 @@ __host__ void
cudaEventCreate(&t_end_instantiate);
cudaEventRecord(t_start_instantiate, stream);

capture_graph(task, regions, ctx, runtime, graph, instance, metas, fused, stream, graph_params, scenario);
capture_graph(task, regions, ctx, runtime, graph, instance);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

cudaEventRecord(t_end_instantiate, stream);
Expand All @@ -1220,7 +1225,7 @@ __host__ void
cudaEventCreate(&t_end_instantiate);
cudaEventRecord(t_start_instantiate, stream);

capture_graph(task, regions, ctx, runtime, graph, instance, metas, fused, stream, graph_params, scenario);
capture_graph(task, regions, ctx, runtime, graph, instance);
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

cudaEventRecord(t_end_instantiate, stream);
Expand All @@ -1236,7 +1241,7 @@ __host__ void
assert(metas->graph_collections.find(graph_params) !=
metas->graph_collections.end());
cudaGraphDestroy(graph);
printf("[%d]FUSED_OP.SCENARIO: %d, %d\n", shard_id, scenario, fused->numOperators);
// printf("[%d]FUSED_OP.SCENARIO: %d, %d\n", shard_id, scenario, fused->numOperators);

cudaEvent_t t_start_launch, t_end_launch;
cudaEventCreate(&t_start_launch);
Expand Down

0 comments on commit a529b67

Please sign in to comment.