Commit
change mask use for nice-slam
wangxiaomeng030 committed Apr 20, 2024
1 parent e170267 commit ada03ef
Showing 5 changed files with 63 additions and 38 deletions.
17 changes: 5 additions & 12 deletions slam/algorithms/base_algorithm.py
@@ -75,7 +75,7 @@ def pre_precessing(self, cur_frame, is_mapping):
         pass

     @abstractmethod
-    def post_processing(self, step, is_mapping, optimizer=None):
+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
         pass

     @abstractmethod
@@ -232,16 +232,6 @@ def do_mapping(self, cur_frame):
                              is_mapping=True,
                              coarse=False)

-        # only for nice-slam
-        if self.config.coarse:
-            # do coarse_mapper
-            optimize_frames = self.select_optimize_frames(
-                cur_frame, keyframe_selection_method='random')
-            self.optimize_update(mapping_n_iters,
-                                 optimize_frames,
-                                 is_mapping=True,
-                                 coarse=True)
-
         if not self.is_initialized():
             self.set_initialized()

@@ -274,7 +264,10 @@ def optimize_update(self,
                     ).clone().cpu().numpy()
                 loss.backward(
                     retain_graph=(self.config.retain_graph and is_mapping))
-                self.post_processing(step, is_mapping, optimizers.optimizers)
+                self.post_processing(step,
+                                     is_mapping,
+                                     optimizers.optimizers,
+                                     coarse=coarse)
                 optimizers.optimizer_step_all(step=step)
                 optimizers.scheduler_step_all()
         # return best c2w by min_loss
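The base class now only runs the regular mapping pass; the NICE-SLAM-specific coarse pass (removed above) moves into nice_slam.py, and the coarse flag is instead threaded through the post_processing hook. A minimal, self-contained sketch of that hook pattern follows (illustrative class names, not the repository's code): declaring coarse with a default keeps other algorithms such as SplaTAM source-compatible while letting NICE-SLAM react to it.

from abc import ABC, abstractmethod


class AlgorithmSketch(ABC):
    """Illustrative stand-in for the base algorithm: the hook gains a defaulted flag."""

    @abstractmethod
    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
        pass


class CoarseAwareSketch(AlgorithmSketch):
    """NICE-SLAM-style subclass: uses the flag to decide which grids to touch."""

    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
        if is_mapping:
            target = 'grid_coarse' if coarse else 'mid/fine/color grids'
            print(f'step {step}: write back {target}')


CoarseAwareSketch().post_processing(0, is_mapping=True, coarse=True)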
17 changes: 11 additions & 6 deletions slam/algorithms/nice_slam.py
@@ -100,12 +100,13 @@ def do_mapping(self, cur_frame):
                              coarse=False)

         # do coarse_mapper
-        optimize_frames = self.select_optimize_frames(
-            cur_frame, keyframe_selection_method='random')
-        self.optimize_update(mapping_n_iters,
-                             optimize_frames,
-                             is_mapping=True,
-                             coarse=True)
+        if self.config.coarse:
+            optimize_frames = self.select_optimize_frames(
+                cur_frame, keyframe_selection_method='random')
+            self.optimize_update(mapping_n_iters,
+                                 optimize_frames,
+                                 is_mapping=True,
+                                 coarse=True)

         if not self.is_initialized():
             self.set_initialized()
@@ -133,6 +134,10 @@ def pre_precessing(self, cur_frame, is_mapping):
         if is_mapping:
             self.model.pre_precessing(cur_frame)

+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
+        if is_mapping:
+            self.model.post_processing(coarse)
+
     def get_model_input(self, optimize_frames, is_mapping):
         batch_rays_d_list = []
         batch_rays_o_list = []
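With this change the coarse mapper is gated by the algorithm config rather than always running: do_mapping first optimizes the regular grids and then, only if config.coarse is set, repeats the update with coarse=True, while the new post_processing simply forwards the flag to the model. A small sketch of that control flow, with hypothetical config and helper names:

from dataclasses import dataclass


@dataclass
class MapperConfigSketch:
    """Stand-in for the algorithm config; only the flag used here."""
    coarse: bool = True


def do_mapping_sketch(config, run_pass):
    # the regular middle/fine/color mapping pass always runs
    run_pass(coarse=False)
    # the coarse pass is optional, mirroring the `if self.config.coarse:` guard above
    if config.coarse:
        run_pass(coarse=True)


do_mapping_sketch(MapperConfigSketch(coarse=True),
                  lambda coarse: print('coarse pass' if coarse else 'fine pass'))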
2 changes: 1 addition & 1 deletion slam/algorithms/splatam.py
@@ -82,7 +82,7 @@ def pre_precessing(self, cur_frame, is_mapping):
         if is_mapping:
             self.model.model_update(cur_frame)

-    def post_processing(self, step, is_mapping, optimizer=None):
+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
         if is_mapping:
             self.model.post_processing(step, optimizer)

12 changes: 4 additions & 8 deletions slam/model_components/decoder_nice.py
@@ -207,17 +207,14 @@ def sample_grid_feature(self, p, c):
     def forward(self, p, c_grid=None):
         if self.c_dim != 0:
             c = self.sample_grid_feature(
-                p,
-                c_grid['grid_' + self.name].val_mask()).transpose(1,
-                                                                   2).squeeze(0)
+                p, c_grid['grid_' + self.name]).transpose(1, 2).squeeze(0)

             if self.concat_feature:
                 # only happen to fine decoder, get feature from middle
                 # level and concat to the current feature
                 with torch.no_grad():
                     c_middle = self.sample_grid_feature(
-                        p, c_grid['grid_middle'].val_mask()).transpose(
-                            1, 2).squeeze(0)
+                        p, c_grid['grid_middle']).transpose(1, 2).squeeze(0)
                 c = torch.cat([c, c_middle], dim=1)

         p = p.float()
@@ -309,9 +306,8 @@ def sample_grid_feature(self, p, grid_feature):
         return c

     def forward(self, p, c_grid, **kwargs):
-        c = self.sample_grid_feature(
-            p, c_grid['grid_' + self.name].val_mask()).transpose(1,
-                                                                  2).squeeze(0)
+        c = self.sample_grid_feature(p, c_grid['grid_' + self.name]).transpose(
+            1, 2).squeeze(0)
         h = c
         for i, l in enumerate(self.pts_linears):
             h = self.pts_linears[i](h)
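Since the grids in grid_c are now plain tensors, the decoders sample them directly instead of going through val_mask(). Roughly, sample_grid_feature trilinearly interpolates the dense feature grid at the query points; the sketch below only illustrates the shape bookkeeping with torch.nn.functional.grid_sample, under the assumption that points are already normalized to [-1, 1] (the repository's code additionally normalizes against the scene bound).

import torch
import torch.nn.functional as F


def sample_grid_feature_sketch(p_nor, grid):
    """Illustrative only. p_nor: (N, 3) points in [-1, 1]; grid: (1, c_dim, X, Y, Z)."""
    vgrid = p_nor[None, :, None, None, :].float()         # (1, N, 1, 1, 3)
    feat = F.grid_sample(grid, vgrid, mode='bilinear',     # trilinear for 5-D input
                         padding_mode='border', align_corners=True)
    return feat.squeeze(-1).squeeze(-1).transpose(1, 2).squeeze(0)   # (N, c_dim)


toy_grid = torch.randn(1, 32, 24, 24, 24)
toy_pts = torch.rand(100, 3) * 2 - 1
print(sample_grid_feature_sketch(toy_pts, toy_grid).shape)   # torch.Size([100, 32])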
53 changes: 42 additions & 11 deletions slam/models/conv_onet.py
@@ -89,21 +89,44 @@ def populate_modules(self):
         self.load_pretrain()
         self.grid_init()
         self.grid_opti_mask = {}
+        self.masked_c_grad = {}

+    def post_processing(self, coarse):
+        if self.config.mapping_frustum_feature_selection:
+            for key, val in self.grid_c.items():
+                if (coarse and key == 'grid_coarse') or (not coarse and
+                                                         key != 'grid_coarse'):
+                    val_grad = self.masked_c_grad[key]
+                    mask = self.masked_c_grad[key + 'mask']
+                    val = val.detach()
+                    val[mask] = val_grad.clone().detach()
+                    self.grid_c[key] = val
+
+    def grid_processing(self, coarse):
+        if self.config.mapping_frustum_feature_selection:
+            for key, val in self.grid_c.items():
+                if (coarse and key == 'grid_coarse') or (not coarse and
+                                                         key != 'grid_coarse'):
+                    val_grad = self.masked_c_grad[key]
+                    mask = self.masked_c_grad[key + 'mask']
+                    val = val.to(self.device)
+                    val[mask] = val_grad
+                    self.grid_c[key] = val
+
     def pre_precessing(self, cur_frame):
         if self.config.mapping_frustum_feature_selection:
             gt_depth_np = cur_frame.depth
             c2w = cur_frame.get_pose()
-            for key, grid in self.grid_c.items():
+            for key, val in self.grid_c.items():
                 mask = get_mask_from_c2w(camera=self.camera,
                                          bound=self.bounding_box,
                                          c2w=c2w,
                                          key=key,
-                                         val_shape=grid.val.shape[2:],
+                                         val_shape=val.shape[2:],
                                          depth_np=gt_depth_np)
                 mask = torch.from_numpy(mask).permute(
                     2, 1, 0).unsqueeze(0).unsqueeze(0).repeat(
-                        1, grid.val.shape[1], 1, 1, 1)
+                        1, val.shape[1], 1, 1, 1)
                 self.grid_opti_mask[key] = mask

     def get_outputs(self, input) -> Dict[str, Union[torch.Tensor, List]]:
@@ -171,12 +194,20 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
         if len(decoders_para_list) > 0:
             param_groups['decoder'] = decoders_para_list
         # grid_params
-        for key, grid in self.grid_c.items():
-            grid = grid.to(self.device)
+        for key, val in self.grid_c.items():
             if self.config.mapping_frustum_feature_selection:
+                val = val.to(self.device)
                 mask = self.grid_opti_mask[key]
-                grid.set_mask(mask)
-            param_groups[key] = list(grid.parameters())
+                val_grad = val[mask].clone()
+                val_grad = torch.nn.Parameter(val_grad.to(self.device))
+                self.masked_c_grad[key] = val_grad
+                self.masked_c_grad[key + 'mask'] = mask
+                param_groups[key] = [val_grad]
+            else:
+                val = torch.nn.Parameter(val.to(self.device))
+                self.grid_c[key] = val
+                param_groups[key] = [val]

         return param_groups

     # only used by mesher
@@ -239,25 +270,25 @@ def grid_init(self):
                 xyz_len=xyz_len * self.config.model_coarse_bound_enlarge,
                 grid_len=coarse_grid_len,
                 c_dim=c_dim,
-                std=0.01)
+                std=0.01).val

         middle_key = 'grid_middle'
         self.grid_c[middle_key] = FeatureGrid(xyz_len=xyz_len,
                                               grid_len=middle_grid_len,
                                               c_dim=c_dim,
-                                              std=0.01)
+                                              std=0.01).val

         fine_key = 'grid_fine'
         self.grid_c[fine_key] = FeatureGrid(xyz_len=xyz_len,
                                             grid_len=fine_grid_len,
                                             c_dim=c_dim,
-                                            std=0.0001)
+                                            std=0.0001).val

         color_key = 'grid_color'
         self.grid_c[color_key] = FeatureGrid(xyz_len=xyz_len,
                                              grid_len=color_grid_len,
                                              c_dim=c_dim,
-                                             std=0.01)
+                                             std=0.01).val

     def load_pretrain(self):
         """This function is modified from nice-slam, licensed under the Apache
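This file carries the core of the change: the grids are stored as raw tensors (FeatureGrid(...).val), and when mapping_frustum_feature_selection is enabled, get_param_groups clones the entries visible in the current frustum into an nn.Parameter (masked_c_grad) that the optimizer updates; post_processing (per step) and grid_processing then write those values back into the full grid, touching grid_coarse only during the coarse pass. A minimal, self-contained sketch of this extract-optimize-write-back pattern, with toy sizes and a stand-in loss (not the repository's code):

import torch

grid = torch.randn(1, 32, 16, 16, 16)               # full feature grid, kept as a plain tensor
mask = torch.zeros_like(grid, dtype=torch.bool)
mask[..., :8, :8, :8] = True                         # e.g. voxels falling inside the view frustum

val_grad = torch.nn.Parameter(grid[mask].clone())    # only the visible entries receive gradients
optimizer = torch.optim.Adam([val_grad], lr=1e-2)

for step in range(5):
    optimizer.zero_grad()
    loss = (val_grad ** 2).mean()                    # stand-in for the real rendering loss
    loss.backward()
    optimizer.step()
    with torch.no_grad():                            # post_processing-style write-back
        grid[mask] = val_grad.detach()

print(grid[mask].abs().mean().item(), grid[~mask].abs().mean().item())

This replaces the previous approach, where masking lived inside the FeatureGrid module itself (set_mask() for optimization and val_mask() when the decoders sampled features).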
