Skip to content
This repository has been archived by the owner on Feb 7, 2025. It is now read-only.

Commit

Permalink
Optimize memory pool (#6)
Browse files Browse the repository at this point in the history
As the title states.
  • Loading branch information
FFFrog authored Jul 16, 2024
1 parent fc79109 commit 89bd6b5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 82 deletions.
36 changes: 0 additions & 36 deletions csrc/npu/NPUBlockHandle.h

This file was deleted.

50 changes: 5 additions & 45 deletions csrc/npu/NPUCachingAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>

#include "Memory.h"
#include "NPUBlockHandle.h"
#include "csrc/npu/CachingAllocatorHelper.h"
#include "csrc/npu/NPUCachingAllocator.h"
#include "npu/acl/include/acl/acl_base.h"
#include "npu/core/NPUGuard.h"
#include "csrc/npu/NPUFunctions.h"
#include "csrc/npu/Memory.h"

#include "npu/core/interface/AsyncTaskQueueInterface.h"
#include "npu/core/sys_ctrl/npu_sys_ctrl.h"
#include "npu/acl/include/acl/acl_base.h"

namespace c10_npu {
namespace NPUCachingAllocator {
Expand Down Expand Up @@ -128,7 +128,7 @@ struct BlockPool {
};

struct Block {
int device; // npu
int device;
void* stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
Expand Down Expand Up @@ -2259,46 +2259,6 @@ void local_raw_delete(void* ptr) {
caching_allocator.free(ptr);
}

// Allocates a block from the caching allocator of `device` and returns it as
// an opaque handle.
//
// size:   requested size, forwarded unchanged to the device allocator.
// stream: stream forwarded to the device allocator; asserted to be non-null.
// device: device index, or -1 to use the device reported by
//         c10_npu::GetDevice().
//
// Returns the allocated block reinterpreted as void*, or nullptr when `device`
// is not a valid index into caching_allocator.device_allocator.
void* MallocBlock(size_t size, void* stream, int device) {
  if (device == -1) {
    NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
  }

  // Valid indices are [0, size()). `device == size()` has to be rejected as
  // well, otherwise the lookup below indexes one past the last entry.
  const int device_count =
      static_cast<int>(caching_allocator.device_allocator.size());
  if (device < 0 || device >= device_count) {
    return nullptr;
  }

  AT_ASSERT(
      caching_allocator.device_allocator[device],
      PTA_ERROR(ErrCode::NOT_FOUND));
  AT_ASSERT(stream, PTA_ERROR(ErrCode::NOT_FOUND));

  auto block =
      caching_allocator.device_allocator[device]->malloc(device, size, stream);
  AT_ASSERT(block, PTA_ERROR(ErrCode::NOT_FOUND));
  return reinterpret_cast<void*>(block);
}

// Interprets `handle` as a Block* and hands it to the free() of the device
// allocator selected by the block's own `device` field.
//
// Asserts with ErrCode::PTR when the handle is null and with
// ErrCode::NOT_FOUND when no allocator exists for that device.
void FreeBlock(void* handle) {
  auto* const blk = reinterpret_cast<Block*>(handle);
  AT_ASSERT(blk, PTA_ERROR(ErrCode::PTR));

  const int dev = blk->device;
  caching_allocator.assertValidDevice(dev);

  auto& dev_allocator = caching_allocator.device_allocator[dev];
  AT_ASSERT(dev_allocator, PTA_ERROR(ErrCode::NOT_FOUND));
  dev_allocator->free(blk);
}

void* GetBlockPtr(const void* handle) {
const Block* block = reinterpret_cast<const Block*>(handle);
AT_ASSERT(block, PTA_ERROR(ErrCode::PTR));
return block->ptr;
}

size_t GetBlockSize(const void* handle) {
const Block* block = reinterpret_cast<const Block*>(handle);
AT_ASSERT(block, PTA_ERROR(ErrCode::PTR));
return block->size;
}

struct BackendStaticInitializer {
BackendStaticInitializer() {
allocator.store(&caching_allocator);
Expand Down
3 changes: 2 additions & 1 deletion csrc/npu/NPUCachingAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ C10_NPU_API std::mutex* getFreeMutex();
// block inside of already allocated area.
// Abstract callback interface: implementations override Execute().
// NOTE(review): the meaning of the bool returned by Execute() is not visible
// in this chunk — confirm against the code that invokes the callbacks.
class FreeMemoryCallback {
 public:
  // Virtual so that deleting an implementation through a FreeMemoryCallback*
  // runs the derived destructor. Declared exactly once.
  virtual ~FreeMemoryCallback() = default;

  virtual bool Execute() = 0;
};

// Declares the registry FreeNPUMemoryCallbacksRegistry, whose registered
// classes are FreeMemoryCallback implementations.
C10_DECLARE_REGISTRY(FreeNPUMemoryCallbacksRegistry, FreeMemoryCallback);

// Registers a class under `name` in FreeNPUMemoryCallbacksRegistry by
// forwarding to C10_REGISTER_CLASS. The expansion already ends in ';'.
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
C10_REGISTER_CLASS(FreeNPUMemoryCallbacksRegistry, name, __VA_ARGS__);

Expand Down

0 comments on commit 89bd6b5

Please sign in to comment.