Skip to content

Commit

Permalink
avoid using WarpX::GetInstance() in AcceleratorLattice
Browse files Browse the repository at this point in the history
  • Loading branch information
lucafedeli88 committed Mar 10, 2025
1 parent acb7289 commit cbb2d27
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 25 deletions.
8 changes: 6 additions & 2 deletions Source/AcceleratorLattice/AcceleratorLattice.H
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,24 @@ public:
*
* @param[in] lev the level of refinement
* @param[in] gamma_boost the Lorentz factor of the boosted frame
* @param[in] time the current time at all refinement levels
* @param[in] ba the box array at the level of refinement
* @param[in] dm the distribution map at the level of refinement
*/
void InitElementFinder (
int lev,
amrex::Real gamma_boost,
const amrex::Vector<amrex::Real>& time,
amrex::BoxArray const & ba,
amrex::DistributionMapping const & dm);

/**
* \brief Update the element finder, needed when the simulation frame has moved relative to the lab frame
*
* @param[in] lev the level of refinement
* @param[in] time the current time at all refinement levels
*/
void UpdateElementFinder (int lev);
void UpdateElementFinder (int lev, const amrex::Vector<amrex::Real>& time);

/* The lattice element finder handles the lookup that finds the elements at the particle locations.
* It should follow the same grid layout as the main grids.
Expand All @@ -70,9 +73,10 @@ public:
*
* @param[in] a_pti the grid where the finder is needed
* @param[in] a_offset the particle offset since the finded needs information about the particles as well
* @param[in] dts vector containing the timestep sizes at all refinement levels
*/
[[nodiscard]] LatticeElementFinderDevice
GetFinderDeviceInstance (WarpXParIter const& a_pti, int a_offset) const;
GetFinderDeviceInstance (WarpXParIter const& a_pti, int a_offset, const amrex::Vector<amrex::Real>& dts) const;

/* All of the available lattice element types */
Drift h_drift;
Expand Down
11 changes: 6 additions & 5 deletions Source/AcceleratorLattice/AcceleratorLattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,32 +78,33 @@ AcceleratorLattice::ReadLattice (std::string const & root_name, amrex::ParticleR
void
AcceleratorLattice::InitElementFinder (
int const lev, amrex::Real const gamma_boost,
const amrex::Vector<amrex::Real>& time,
amrex::BoxArray const & ba, amrex::DistributionMapping const & dm)
{
if (m_lattice_defined) {
m_element_finder = std::make_unique<amrex::LayoutData<LatticeElementFinder>>(ba, dm);
for (amrex::MFIter mfi(*m_element_finder); mfi.isValid(); ++mfi)
{
(*m_element_finder)[mfi].InitElementFinder(lev, gamma_boost, mfi, *this);
(*m_element_finder)[mfi].InitElementFinder(lev, gamma_boost, time, mfi, *this);
}
}
}

void
AcceleratorLattice::UpdateElementFinder (int const lev) // NOLINT(readability-make-member-function-const)
AcceleratorLattice::UpdateElementFinder (int const lev, const amrex::Vector<amrex::Real>& time) // NOLINT(readability-make-member-function-const)
{ // Techniquely clang-tidy is correct because
// m_element_finder is unique_ptr, not const*.
if (m_lattice_defined) {
for (amrex::MFIter mfi(*m_element_finder); mfi.isValid(); ++mfi)
{
(*m_element_finder)[mfi].UpdateIndices(lev, mfi, *this);
(*m_element_finder)[mfi].UpdateIndices(lev, mfi, *this, time);
}
}
}

LatticeElementFinderDevice
AcceleratorLattice::GetFinderDeviceInstance (WarpXParIter const& a_pti, int const a_offset) const
AcceleratorLattice::GetFinderDeviceInstance (WarpXParIter const& a_pti, int const a_offset, const amrex::Vector<amrex::Real>& dts) const
{
const LatticeElementFinder & finder = (*m_element_finder)[a_pti];
return finder.GetFinderDeviceInstance(a_pti, a_offset, *this);
return finder.GetFinderDeviceInstance(a_pti, a_offset, *this, dts);
}
14 changes: 11 additions & 3 deletions Source/AcceleratorLattice/LatticeElementFinder.H
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ struct LatticeElementFinder
*
* @param[in] lev the refinement level
* @param[in] gamma_boost the Lorentz factor of the boosted frame
* @param[in] time the current time on all refinement levels
* @param[in] a_mfi specifies the grid where the finder is defined
* @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
*/
void InitElementFinder (int lev, amrex::Real gamma_boost,
const amrex::Vector<amrex::Real>& time,
amrex::MFIter const& a_mfi,
AcceleratorLattice const& accelerator_lattice);

Expand All @@ -51,9 +53,11 @@ struct LatticeElementFinder
* @param[in] lev the refinement level
* @param[in] a_mfi specifies the grid where the finder is defined
* @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
* @param[in] time the current time on all refinement levels
*/
void UpdateIndices (int lev, amrex::MFIter const& a_mfi,
AcceleratorLattice const& accelerator_lattice);
AcceleratorLattice const& accelerator_lattice,
const amrex::Vector<amrex::Real>& time);

/* Define the location and size of the index lookup table */
/* Use the type Real to be consistent with the way the main grid is defined */
Expand All @@ -73,9 +77,11 @@ struct LatticeElementFinder
* @param[in] a_pti specifies the grid where the finder is defined
* @param[in] a_offset particle index offset needed to access particle info
* @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
* @param[in] dts vector containing the timestep sizes at all refinement levels
*/
[[nodiscard]] LatticeElementFinderDevice GetFinderDeviceInstance (
WarpXParIter const& a_pti, int a_offset, AcceleratorLattice const& accelerator_lattice) const;
WarpXParIter const& a_pti, int a_offset, AcceleratorLattice const& accelerator_lattice,
const amrex::Vector<amrex::Real>& dts) const;

/* The index lookup tables for each lattice element type */
amrex::Gpu::DeviceVector<int> d_quad_indices;
Expand Down Expand Up @@ -108,11 +114,13 @@ struct LatticeElementFinderDevice
* @param[in] a_offset particle index offset needed to access particle info
* @param[in] accelerator_lattice a reference to the accelerator lattice at the refinement level
* @param[in] h_finder The host level instance of the element finder that this is associated with
* @param[in] dts vector containing the timestep sizes at all refinement levels
*/
void
InitLatticeElementFinderDevice (WarpXParIter const& a_pti, int a_offset,
AcceleratorLattice const& accelerator_lattice,
LatticeElementFinder const & h_finder);
LatticeElementFinder const & h_finder,
const amrex::Vector<amrex::Real>& dts);

/* Whether the class has been initialized */
bool m_initialized = false;
Expand Down
22 changes: 11 additions & 11 deletions Source/AcceleratorLattice/LatticeElementFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using namespace amrex::literals;

void
LatticeElementFinder::InitElementFinder (int const lev, const amrex::Real gamma_boost,
const amrex::Vector<amrex::Real>& time,
amrex::MFIter const& a_mfi,
AcceleratorLattice const& accelerator_lattice)
{
Expand All @@ -32,7 +33,7 @@ LatticeElementFinder::InitElementFinder (int const lev, const amrex::Real gamma_

AllocateIndices(accelerator_lattice);

UpdateIndices(lev, a_mfi, accelerator_lattice);
UpdateIndices(lev, a_mfi, accelerator_lattice, time);

}

Expand All @@ -53,16 +54,15 @@ LatticeElementFinder::AllocateIndices (AcceleratorLattice const& accelerator_lat

void
LatticeElementFinder::UpdateIndices (int const lev, amrex::MFIter const& a_mfi,
AcceleratorLattice const& accelerator_lattice)
AcceleratorLattice const& accelerator_lattice,
const amrex::Vector<amrex::Real>& time)
{
auto& warpx = WarpX::GetInstance();

// Update the location of the index grid.
// Note that the current box is used since the box may have been updated since
// the initialization in InitElementFinder.
const amrex::Box box = a_mfi.tilebox();
m_zmin = WarpX::LowerCorner(box, lev, 0._rt).z;
m_time = warpx.gett_new(lev);
m_time = time[lev];

if (accelerator_lattice.h_quad.nelements > 0) {
setup_lattice_indices(accelerator_lattice.h_quad.d_zs,
Expand All @@ -79,31 +79,31 @@ LatticeElementFinder::UpdateIndices (int const lev, amrex::MFIter const& a_mfi,

LatticeElementFinderDevice
LatticeElementFinder::GetFinderDeviceInstance (WarpXParIter const& a_pti, int const a_offset,
AcceleratorLattice const& accelerator_lattice) const
AcceleratorLattice const& accelerator_lattice,
const amrex::Vector<amrex::Real>& dts) const
{
LatticeElementFinderDevice result;
result.InitLatticeElementFinderDevice(a_pti, a_offset, accelerator_lattice, *this);
result.InitLatticeElementFinderDevice(a_pti, a_offset, accelerator_lattice, *this, dts);
return result;
}

void
LatticeElementFinderDevice::InitLatticeElementFinderDevice (WarpXParIter const& a_pti, int const a_offset,
AcceleratorLattice const& accelerator_lattice,
LatticeElementFinder const & h_finder)
LatticeElementFinder const & h_finder,
const amrex::Vector<amrex::Real>& dts)
{

m_initialized = true;

auto& warpx = WarpX::GetInstance();

int const lev = a_pti.GetLevel();

m_get_position = GetParticlePosition<PIdx>(a_pti, a_offset);
const auto& attribs = a_pti.GetAttribs();
m_ux = attribs[PIdx::ux].dataPtr() + a_offset;
m_uy = attribs[PIdx::uy].dataPtr() + a_offset;
m_uz = attribs[PIdx::uz].dataPtr() + a_offset;
m_dt = warpx.getdt(lev);
m_dt = dts[lev];

m_gamma_boost = WarpX::gamma_boost;
m_uz_boost = std::sqrt(WarpX::gamma_boost*WarpX::gamma_boost - 1._prt)*PhysConst::c;
Expand Down
2 changes: 1 addition & 1 deletion Source/Evolve/WarpXEvolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ WarpX::Evolve (int numsteps)
// from either a moving window or a boosted frame
if (num_moved != 0 || gamma_boost > 1) {
for (int lev = 0; lev <= finest_level; ++lev) {
m_accelerator_lattice[lev]->UpdateElementFinder(lev);
m_accelerator_lattice[lev]->UpdateElementFinder(lev, gett_new());
}
}

Expand Down
2 changes: 1 addition & 1 deletion Source/Parallelization/WarpXRegrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ WarpX::RemakeLevel (int lev, Real /*time*/, const BoxArray& ba, const Distributi
}

// Re-initialize the lattice element finder with the new ba and dm.
m_accelerator_lattice[lev]->InitElementFinder(lev, gamma_boost, ba, dm);
m_accelerator_lattice[lev]->InitElementFinder(lev, gamma_boost, gett_new(), ba, dm);

if (costs[lev] != nullptr)
{
Expand Down
4 changes: 3 additions & 1 deletion Source/Particles/Gather/GetExternalFields.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ GetExternalEBField::GetExternalEBField (const WarpXParIter& a_pti, long a_offset

const int lev = a_pti.GetLevel();

const auto& dts = warpx.getdt();

AcceleratorLattice const & accelerator_lattice = warpx.get_accelerator_lattice(lev);
if (accelerator_lattice.m_lattice_defined) {
d_lattice_element_finder = accelerator_lattice.GetFinderDeviceInstance(a_pti, static_cast<int>(a_offset));
d_lattice_element_finder = accelerator_lattice.GetFinderDeviceInstance(a_pti, static_cast<int>(a_offset), dts);
}

m_gamma_boost = WarpX::gamma_boost;
Expand Down
2 changes: 1 addition & 1 deletion Source/WarpX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2158,7 +2158,7 @@ WarpX::AllocLevelData (int lev, const BoxArray& ba, const DistributionMapping& d
guard_cells.ng_alloc_Rho, guard_cells.ng_alloc_F, guard_cells.ng_alloc_G, aux_is_nodal);

m_accelerator_lattice[lev] = std::make_unique<AcceleratorLattice>();
m_accelerator_lattice[lev]->InitElementFinder(lev, gamma_boost, ba, dm);
m_accelerator_lattice[lev]->InitElementFinder(lev, gamma_boost, gett_new(), ba, dm);

}

Expand Down

0 comments on commit cbb2d27

Please sign in to comment.