Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fft cleanup #12

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
build/
ciderpress/lib/deps/
ciderpress/lib/pwutil/config.h
ciderpress/lib/fft_wrapper/cider_fft_config.h
dist/
docs/_build
docs/_static
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ recursive-include ciderpress/lib *.c *.h *.h.in CMakeLists.txt
recursive-exclude ciderpress/lib *.cl

global-exclude *.py[cod]
prune ciderpress/lib/build

# docs
recursive-exclude docs/
42 changes: 42 additions & 0 deletions ciderpress/gpaw/tests/test_gpaw_grids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest

import numpy as np
from numpy.testing import assert_allclose

from ciderpress.gpaw.gpaw_grids import SBTFullGridDescriptor

LMAX = 6


def gaussian(l, alpha, r):
return r**l * np.exp(-alpha * r * r)


def gaussian_ft(l, alpha, k):
const = np.pi**1.5 / 2**l / alpha ** (1.5 + l)
return const * k**l * np.exp((-0.25 / alpha) * k * k)


class TestSBT(unittest.TestCase):
def test_sbt_gaussian(self):
sbtgd = SBTFullGridDescriptor(0.001, 1e12, 0.02, N=1024, lmax=LMAX)
for l in range(LMAX + 1):
for alpha in [0.01, 0.1, 1.0, 10.0, 100.0]:
scale = alpha ** (1.5 + 0.5 * l)
f_g = scale * gaussian(l, alpha, sbtgd.r_g)
fref_k = scale * gaussian_ft(l, alpha, sbtgd.k_g)
ftest_k = 4 * np.pi * sbtgd.transform_single_fwd(f_g, l)
print("MINMAX", np.min(sbtgd.k_g), np.max(sbtgd.k_g))
print(
np.linalg.norm(ftest_k - fref_k),
np.linalg.norm(fref_k),
np.linalg.norm(ftest_k),
)
print(np.max(np.abs(ftest_k - fref_k)))
assert_allclose(ftest_k, fref_k, atol=1e-4, rtol=0)
ftest_g = (0.25 / np.pi) * sbtgd.transform_single_bwd(ftest_k, l)
assert_allclose(ftest_g, f_g, atol=1e-4, rtol=0)


if __name__ == "__main__":
unittest.main()
127 changes: 80 additions & 47 deletions ciderpress/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


cmake_minimum_required (VERSION 3.5)
cmake_minimum_required (VERSION 3.20)
project (ciderpress)

if (NOT CMAKE_BUILD_TYPE)
Expand All @@ -18,7 +18,10 @@ set(CMAKE_C_FLAGS_RELEASE "-g -O3")
list( APPEND CMAKE_BUILD_RPATH ${CMAKE_PREFIX_PATH}/lib )
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)

option(BUILD_WITH_MKL "use MKL for BLAS and FFT" off)
option(BUILD_FFTW "Build fftw3" on)
option(BUILD_MARCH_NATIVE "gcc flag -march=native" off)
set(ASSUME_MPI_VENDOR "openmpi")
if (BUILD_MARCH_NATIVE)
include(CheckCCompilerFlag)
CHECK_C_COMPILER_FLAG("-march=native" COMPILER_SUPPORTS_MARCH_NATIVE)
Expand Down Expand Up @@ -51,49 +54,89 @@ else ()
set(OpenMP_C_FLAGS " ")
endif()

# We want MKL to use the same threading as the rest of the project
set(MKL_INTERFACE lp64)
if("iomp5" IN_LIST OpenMP_C_LIB_NAMES)
set(MKL_THREADING intel_thread)
elseif("gomp" IN_LIST OpenMP_C_LIB_NAMES)
set(MKL_THREADING gnu_thread)
else()
set(MKL_THREADING sequential)
endif()
find_package(MKL CONFIG REQUIRED PATHS $ENV{MKLROOT})
set(MKL_INCLUDE_DIR ${MKL_ROOT}/include)
message(STATUS "Imported MKL targets: ${MKL_IMPORTED_TARGETS}")
message(STATUS "MKL include path is: ${MKL_INCLUDE_DIR}")
if (NOT MKL_FOUND)
message(FATAL_ERROR "MKL not found")
endif()

if (NOT BLAS_LIBRARIES)
find_package(BLAS)
check_function_exists(ffsll HAVE_FFS)
endif()

find_package(MPI)
if (MPI_LIBRARIES)
set(HAVE_MPI 1)
set(FFTW_CONFIGURE_ARGS --enable-static=no --enable-shared=yes --enable-threads --enable-mpi=yes --enable-openmp MPILIBS=-lmpi)
message(STATUS "Found MPI: ${FFTW_CONFIGURE_ARGS}")
else ()
set(HAVE_MPI 0)
set(FFTW_CONFIGURE_ARGS --enable-static=no --enable-shared=yes --enable-threads --enable-openmp)
message(STATUS "Did not find MPI: ${FFTW_CONFIGURE_ARGS}")
endif()
# include_directories(${CMAKE_PYTHON_INCLUDE_PATH})
# link_directories(${CMAKE_PYTHON_LIBRARY_PATH})

find_package(Python REQUIRED COMPONENTS Interpreter Development)

if (NOT BLAS_LIBRARIES)
message(FATAL_ERROR "A required library with BLAS API not found.")
if (BUILD_WITH_MKL)
# Use MKL for the BLAS and FFT libraries
# We want MKL to use the same threading as the rest of the project
set(MKL_INTERFACE lp64)
if("iomp5" IN_LIST OpenMP_C_LIB_NAMES)
set(MKL_THREADING intel_thread)
elseif("gomp" IN_LIST OpenMP_C_LIB_NAMES)
set(MKL_THREADING gnu_thread)
else()
set(MKL_THREADING sequential)
endif()
if (HAVE_MPI)
message(STATUS "THESE ARE THE MPI CXX LIBRARIES ${MPI_CXX_LIBRARIES}")
if("${MPI_CXX_LIBRARIES}" MATCHES openmpi)
set(MKL_MPI openmpi)
elseif("${MPI_CXX_LIBRARIES}" MATCHES intel)
set(MKL_MPI intelmpi)
elseif("${MPI_CXX_LIBRARIES}" MATCHES mpich)
set(MKL_MPI mpich)
else()
set(MKL_MPI "${ASSUME_MPI_VENDOR}")
message(WARNING "Unknown MPI when setting up MKL, assuming ${ASSUME_MPI_VENDOR}")
endif()
set(ENABLE_CDFT ON)
set(ENABLE_BLACS ON)
endif()
set(MKL_LINK dynamic)
find_package(MKL CONFIG REQUIRED PATHS $ENV{MKLROOT})
set(MKL_INCLUDE_DIR ${MKL_ROOT}/include)
message(STATUS "Imported MKL targets: ${MKL_IMPORTED_TARGETS}")
message(STATUS "MKL include path is: ${MKL_INCLUDE_DIR}")
if(NOT MKL_FOUND)
message(FATAL_ERROR "MKL not found")
else()
string(APPEND CMAKE_SHARED_LINKER_FLAGS " -Wl,--no-as-needed")
endif()
message(STATUS "LINKER FLAGS ${CMAKE_SHARED_LINKER_FLAGS}")
set(FFT_BACKEND_NUMBER 0)
else()
message(STATUS "BLAS libraries: ${BLAS_LIBRARIES}")
if (NOT BLAS_LIBRARIES)
find_package(BLAS)
check_function_exists(ffsll HAVE_FFS)
endif()

set(FFTW_CONFIGURE_ARGS --enable-static=no --enable-shared=yes --enable-threads --enable-openmp)
if (BUILD_MARCH_NATIVE)
list(APPEND FFTW_CONFIGURE_ARGS --enable-avx --enable-sse2 --enable-avx2 --enable-avx512)
endif()
if (HAVE_MPI)
list(APPEND FFTW_CONFIGURE_ARGS --enable-mpi=yes MPILIBS=-lmpi)
endif()

if (NOT BLAS_LIBRARIES)
message(FATAL_ERROR "A required library with BLAS API not found.")
else()
message(STATUS "BLAS libraries: ${BLAS_LIBRARIES}")
endif()
set(FFT_BACKEND_NUMBER 1)

include(ExternalProject)
if(BUILD_FFTW)
ExternalProject_Add(libfftw3
URL https://www.fftw.org/fftw-3.3.10.tar.gz
PREFIX ${PROJECT_BINARY_DIR}/deps
INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps
BUILD_IN_SOURCE True
CONFIGURE_COMMAND ./configure ${FFTW_CONFIGURE_ARGS} CXX=${CMAKE_CXX_COMPILER} CC=${CMAKE_C_COMPILER} prefix=<INSTALL_DIR>
BUILD_COMMAND make -j4 install
)
endif()
endif()

find_package(Python REQUIRED COMPONENTS Interpreter Development)

include_directories(${PROJECT_SOURCE_DIR})
include_directories(${PROJECT_SOURCE_DIR}/deps/include)
include_directories(${CMAKE_INSTALL_PREFIX}/include)
Expand All @@ -106,28 +149,18 @@ set(CMAKE_BUILD_WITH_INSTALL_RPATH True)
set(CMAKE_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/deps/lib:\$ORIGIN/deps/lib64")
message(RPATH=${CMAKE_INSTALL_RPATH})

include(ExternalProject)
option(ENABLE_FFTW "Using fftw3" ON)
option(BUILD_FFTW "Building fftw3" ON)

string(REPLACE ":" ";" _lib_path "$ENV{LD_LIBRARY_PATH}")
find_library(LIBXC_LIBRARIES NAMES xc PATHS ${_lib_path})
add_subdirectory(mod_cider)
add_subdirectory(numint_cider)
add_subdirectory(pwutil)
add_subdirectory(sbt)
add_subdirectory(xc_utils)

if(ENABLE_FFTW AND BUILD_FFTW)
ExternalProject_Add(libfftw3
URL https://www.fftw.org/fftw-3.3.10.tar.gz
PREFIX ${PROJECT_BINARY_DIR}/deps
INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps
BUILD_IN_SOURCE True
CONFIGURE_COMMAND ./configure ${FFTW_CONFIGURE_ARGS} CXX=${CMAKE_CXX_COMPILER} CC=${CMAKE_C_COMPILER} prefix=<INSTALL_DIR>
BUILD_COMMAND make -j4 install
)
add_dependencies(pwutil libfftw3)
add_subdirectory(fft_wrapper)
add_dependencies(mcider fft_wrapper)
add_dependencies(sbt fft_wrapper)
if (BUILD_FFTW AND NOT BUILD_WITH_MKL)
add_dependencies(fft_wrapper libfftw3)
endif()

# NOTE some other stuff from pyscf CMake file was here
Expand Down
90 changes: 90 additions & 0 deletions ciderpress/lib/fft_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import ctypes

import numpy as np

from ciderpress.lib import load_library

libfft = load_library("libfft_wrapper.so")

libfft.allocate_fftnd_plan.restype = ctypes.c_void_p
libfft.malloc_fft_plan_in_array.restype = ctypes.c_void_p
libfft.malloc_fft_plan_out_array.restype = ctypes.c_void_p
libfft.get_fft_plan_in_array.restype = ctypes.c_void_p
libfft.get_fft_plan_out_array.restype = ctypes.c_void_p
libfft.get_fft_input_size.restype = ctypes.c_int
libfft.get_fft_output_size.restype = ctypes.c_int


class FFTWrapper:
def __init__(
self, dims, ntransform=1, fwd=True, r2c=False, inplace=False, batch_first=True
):
self._dims = dims
self._ntransform = ntransform
self._fwd = fwd
self._r2c = r2c
self._inplace = inplace
self._batch_first = batch_first
rshape = [d for d in dims]
if r2c:
kshape = [d for d in dims[:-1]] + [dims[-1] // 2 + 1]
else:
kshape = [d for d in dims]
if batch_first:
rshape.insert(0, self._ntransform)
kshape.insert(0, self._ntransform)
else:
rshape.append(self._ntransform)
kshape.append(self._ntransform)
if fwd:
self._inshape = tuple(rshape)
self._outshape = tuple(kshape)
else:
self._inshape = tuple(kshape)
self._outshape = tuple(rshape)
dims = np.asarray(dims, dtype=np.int32)
self._ptr = ctypes.c_void_p(
libfft.allocate_fftnd_plan(
ctypes.c_int(len(dims)),
dims.ctypes.data_as(ctypes.c_void_p),
ctypes.c_int(1 if fwd else 0),
ctypes.c_int(1 if r2c else 0),
ctypes.c_int(self._ntransform),
ctypes.c_int(1 if inplace else 0),
ctypes.c_int(1 if batch_first else 0),
)
)
self._in = ctypes.c_void_p(libfft.malloc_fft_plan_in_array(self._ptr))
if self._inplace:
self._out = None
else:
self._out = ctypes.c_void_p(libfft.malloc_fft_plan_out_array(self._ptr))
libfft.initialize_fft_plan(
self._ptr,
self._in,
self._out,
)

def __del__(self):
libfft.free_fft_plan(self._ptr)
libfft.free_fft_array(self._in)
if not self._inplace:
libfft.free_fft_array(self._out)

@property
def input_shape(self):
return self._inshape

@property
def output_shape(self):
return self._outshape

def call(self, x):
if x.shape != self._inshape:
raise ValueError(f"Expected input of shape {self._inshape}, got {x.shape}")
dtype = np.float64 if (self._r2c and not self._fwd) else np.complex128
out = np.empty(self._outshape, dtype=dtype)
libfft.write_fft_input(self._ptr, x.ctypes.data_as(ctypes.c_void_p))
libfft.execute_fft_plan(self._ptr)
libfft.read_fft_output(self._ptr, out.ctypes.data_as(ctypes.c_void_p))
return out
34 changes: 34 additions & 0 deletions ciderpress/lib/fft_wrapper/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

add_library(fft_wrapper SHARED
cider_fft.c
cider_mpi_fft.c
)

set_target_properties(fft_wrapper PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}
COMPILE_FLAGS ${OpenMP_C_FLAGS}
LINK_FLAGS ${OpenMP_C_FLAGS})

configure_file(
${PROJECT_SOURCE_DIR}/fft_wrapper/config.h.in
${PROJECT_SOURCE_DIR}/fft_wrapper/cider_fft_config.h
NEWLINE_STYLE UNIX
)

if (HAVE_MPI)
target_link_libraries(fft_wrapper PUBLIC MPI::MPI_C)
endif()

if (FFT_BACKEND_NUMBER EQUAL 0)
target_include_directories(fft_wrapper PUBLIC ${MKL_INCLUDE_DIR})
target_link_libraries(fft_wrapper PUBLIC MKL::MKL)
if (HAVE_MPI)
target_link_libraries(fft_wrapper PUBLIC MKL::MKL_CDFT)
endif()
else()
target_link_libraries(fft_wrapper PRIVATE fftw3)
target_link_libraries(fft_wrapper PRIVATE fftw3_omp)
if (HAVE_MPI)
target_link_libraries(fft_wrapper PRIVATE fftw3_mpi)
endif()
endif()
Loading
Loading