Skip to content

Commit 2dce364

Browse files
desertfirepytorchmergebot
authored andcommitted
[AOTI][refactor] Remove model_container_runner_cuda.cpp (pytorch#116113)
Differential Revision: [D52301272](https://our.internmc.facebook.com/intern/diff/D52301272) Pull Request resolved: pytorch#116113 Approved by: https://github.com/khabinov ghstack dependencies: pytorch#116047
1 parent f71d302 commit 2dce364

File tree

4 files changed

+12
-19
lines changed

4 files changed

+12
-19
lines changed

build_variables.bzl

-1
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,6 @@ libtorch_cuda_core_sources = [
652652
"torch/csrc/CudaIPCTypes.cpp",
653653
"torch/csrc/cuda/comm.cpp",
654654
"torch/csrc/cuda/memory_snapshot.cpp",
655-
"torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
656655
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
657656
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
658657
"torch/csrc/profiler/stubs/cuda.cpp",

torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner {
99
const std::string& model_so_path,
1010
size_t num_models = 1)
1111
: AOTIModelContainerRunner(model_so_path, num_models, true, "") {}
12+
13+
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs) {
14+
return AOTIModelContainerRunner::run(inputs);
15+
}
1216
};
1317

1418
} // namespace torch::inductor

torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp

-16
This file was deleted.

torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <cuda_runtime_api.h>
3+
#include <c10/cuda/CUDAStream.h>
44
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
55

66
namespace torch::inductor {
@@ -15,7 +15,13 @@ class TORCH_API AOTIModelContainerRunnerCuda : public AOTIModelContainerRunner {
1515

1616
std::vector<at::Tensor> run(
1717
std::vector<at::Tensor>& inputs,
18-
cudaStream_t cuda_stream_handle = nullptr);
18+
cudaStream_t cuda_stream_handle = nullptr) {
19+
if (cuda_stream_handle == nullptr) {
20+
cuda_stream_handle = c10::cuda::getCurrentCUDAStream().stream();
21+
}
22+
return AOTIModelContainerRunner::run(
23+
inputs, reinterpret_cast<AOTInductorStreamHandle>(cuda_stream_handle));
24+
}
1925
};
2026

2127
} // namespace torch::inductor

0 commit comments

Comments
 (0)