diff --git a/slam/algorithms/base_algorithm.py b/slam/algorithms/base_algorithm.py
index a30c053..9def539 100644
--- a/slam/algorithms/base_algorithm.py
+++ b/slam/algorithms/base_algorithm.py
@@ -75,7 +75,7 @@ def pre_precessing(self, cur_frame, is_mapping):
         pass

     @abstractmethod
-    def post_processing(self, step, is_mapping, optimizer=None):
+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
         pass

     @abstractmethod
@@ -232,16 +232,6 @@ def do_mapping(self, cur_frame):
                              is_mapping=True,
                              coarse=False)

-        # only for nice-slam
-        if self.config.coarse:
-            # do coarse_mapper
-            optimize_frames = self.select_optimize_frames(
-                cur_frame, keyframe_selection_method='random')
-            self.optimize_update(mapping_n_iters,
-                                 optimize_frames,
-                                 is_mapping=True,
-                                 coarse=True)
-
         if not self.is_initialized():
             self.set_initialized()

@@ -274,7 +264,10 @@ def optimize_update(self,
                 ).clone().cpu().numpy()
             loss.backward(
                 retain_graph=(self.config.retain_graph and is_mapping))
-            self.post_processing(step, is_mapping, optimizers.optimizers)
+            self.post_processing(step,
+                                 is_mapping,
+                                 optimizers.optimizers,
+                                 coarse=coarse)
             optimizers.optimizer_step_all(step=step)
             optimizers.scheduler_step_all()
         # return best c2w by min_loss
diff --git a/slam/algorithms/nice_slam.py b/slam/algorithms/nice_slam.py
index 2a0b674..7a89cd5 100644
--- a/slam/algorithms/nice_slam.py
+++ b/slam/algorithms/nice_slam.py
@@ -100,12 +100,13 @@ def do_mapping(self, cur_frame):
                              coarse=False)

         # do coarse_mapper
-        optimize_frames = self.select_optimize_frames(
-            cur_frame, keyframe_selection_method='random')
-        self.optimize_update(mapping_n_iters,
-                             optimize_frames,
-                             is_mapping=True,
-                             coarse=True)
+        if self.config.coarse:
+            optimize_frames = self.select_optimize_frames(
+                cur_frame, keyframe_selection_method='random')
+            self.optimize_update(mapping_n_iters,
+                                 optimize_frames,
+                                 is_mapping=True,
+                                 coarse=True)

         if not self.is_initialized():
             self.set_initialized()
@@ -133,6 +134,10 @@ def pre_precessing(self, cur_frame, is_mapping):
         if is_mapping:
             self.model.pre_precessing(cur_frame)

+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
+        if is_mapping:
+            self.model.post_processing(coarse)
+
     def get_model_input(self, optimize_frames, is_mapping):
         batch_rays_d_list = []
         batch_rays_o_list = []
diff --git a/slam/algorithms/splatam.py b/slam/algorithms/splatam.py
index 89ce563..0274753 100644
--- a/slam/algorithms/splatam.py
+++ b/slam/algorithms/splatam.py
@@ -82,7 +82,7 @@ def pre_precessing(self, cur_frame, is_mapping):
         if is_mapping:
             self.model.model_update(cur_frame)

-    def post_processing(self, step, is_mapping, optimizer=None):
+    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
         if is_mapping:
             self.model.post_processing(step, optimizer)
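
Note: the hunks above thread a `coarse` keyword through the abstract `post_processing` hook, defaulting to `False` so existing overrides stay call-compatible even when they ignore it. A minimal, self-contained sketch of the resulting contract; the class and model names here are hypothetical stand-ins, not the repo's code:

```python
from abc import ABC, abstractmethod


class StubModel:
    """Hypothetical stand-in for the per-algorithm model."""

    def post_processing(self, *args):
        print('model.post_processing called with', args)


class AlgorithmBase(ABC):

    def __init__(self):
        self.model = StubModel()

    @abstractmethod
    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
        pass


class NiceSlamLike(AlgorithmBase):

    # Forwards only the coarse flag to the model, as nice_slam.py now does.
    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
        if is_mapping:
            self.model.post_processing(coarse)


class SplatamLike(AlgorithmBase):

    # Accepts the new keyword for a uniform signature but does not use it.
    def post_processing(self, step, is_mapping, optimizer=None, coarse=False):
        if is_mapping:
            self.model.post_processing(step, optimizer)


NiceSlamLike().post_processing(step=0, is_mapping=True, coarse=True)
SplatamLike().post_processing(step=0, is_mapping=True)
```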
diff --git a/slam/model_components/decoder_nice.py b/slam/model_components/decoder_nice.py
index 355e3c2..ac49476 100755
--- a/slam/model_components/decoder_nice.py
+++ b/slam/model_components/decoder_nice.py
@@ -207,17 +207,14 @@ def sample_grid_feature(self, p, c):
     def forward(self, p, c_grid=None):
         if self.c_dim != 0:
             c = self.sample_grid_feature(
-                p,
-                c_grid['grid_' + self.name].val_mask()).transpose(1,
-                                                                  2).squeeze(0)
+                p, c_grid['grid_' + self.name]).transpose(1, 2).squeeze(0)

             if self.concat_feature:
                 # only happen to fine decoder, get feature from middle
                 # level and concat to the current feature
                 with torch.no_grad():
                     c_middle = self.sample_grid_feature(
-                        p, c_grid['grid_middle'].val_mask()).transpose(
-                            1, 2).squeeze(0)
+                        p,
+                        c_grid['grid_middle']).transpose(1, 2).squeeze(0)
                 c = torch.cat([c, c_middle], dim=1)

         p = p.float()
@@ -309,9 +306,8 @@ def sample_grid_feature(self, p, grid_feature):
         return c

     def forward(self, p, c_grid, **kwargs):
-        c = self.sample_grid_feature(
-            p, c_grid['grid_' + self.name].val_mask()).transpose(1,
-                                                                 2).squeeze(0)
+        c = self.sample_grid_feature(p, c_grid['grid_' + self.name]).transpose(
+            1, 2).squeeze(0)
         h = c
         for i, l in enumerate(self.pts_linears):
             h = self.pts_linears[i](h)
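
Note: with `.val_mask()` gone, the decoders above index `c_grid` for raw `(1, c_dim, D, H, W)` feature tensors and sample them directly. A minimal sketch of what a `sample_grid_feature`-style lookup can look like; this is a hypothetical helper assuming query points already normalized to [-1, 1] in xyz order, not the repo's exact implementation:

```python
import torch
import torch.nn.functional as F


def sample_grid_feature_sketch(p_norm: torch.Tensor,
                               grid: torch.Tensor) -> torch.Tensor:
    """p_norm: (N, 3) query points in [-1, 1]; grid: (1, c_dim, D, H, W)."""
    # grid_sample over a 5-D input expects a (1, 1, 1, N, 3) sampling grid.
    vgrid = p_norm[None, None, None, :, :]
    feats = F.grid_sample(grid,
                          vgrid,
                          padding_mode='border',
                          align_corners=True,
                          mode='bilinear')  # trilinear for volumetric input
    return feats.squeeze(2).squeeze(2)  # (1, c_dim, N)


# The callers above then apply .transpose(1, 2).squeeze(0) to get (N, c_dim).
c = sample_grid_feature_sketch(torch.rand(128, 3) * 2 - 1,
                               torch.randn(1, 32, 16, 16, 16))
print(c.transpose(1, 2).squeeze(0).shape)  # torch.Size([128, 32])
```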
diff --git a/slam/models/conv_onet.py b/slam/models/conv_onet.py
index 9619319..fef8663 100644
--- a/slam/models/conv_onet.py
+++ b/slam/models/conv_onet.py
@@ -89,21 +89,44 @@ def populate_modules(self):
         self.load_pretrain()
         self.grid_init()
         self.grid_opti_mask = {}
+        self.masked_c_grad = {}
+
+    def post_processing(self, coarse):
+        if self.config.mapping_frustum_feature_selection:
+            for key, val in self.grid_c.items():
+                if (coarse and key == 'grid_coarse') or (not coarse and
+                                                         key != 'grid_coarse'):
+                    val_grad = self.masked_c_grad[key]
+                    mask = self.masked_c_grad[key + 'mask']
+                    val = val.detach()
+                    val[mask] = val_grad.clone().detach()
+                    self.grid_c[key] = val
+
+    def grid_processing(self, coarse):
+        if self.config.mapping_frustum_feature_selection:
+            for key, val in self.grid_c.items():
+                if (coarse and key == 'grid_coarse') or (not coarse and
+                                                         key != 'grid_coarse'):
+                    val_grad = self.masked_c_grad[key]
+                    mask = self.masked_c_grad[key + 'mask']
+                    val = val.to(self.device)
+                    val[mask] = val_grad
+                    self.grid_c[key] = val

     def pre_precessing(self, cur_frame):
         if self.config.mapping_frustum_feature_selection:
             gt_depth_np = cur_frame.depth
             c2w = cur_frame.get_pose()
-            for key, grid in self.grid_c.items():
+            for key, val in self.grid_c.items():
                 mask = get_mask_from_c2w(camera=self.camera,
                                          bound=self.bounding_box,
                                          c2w=c2w,
                                          key=key,
-                                         val_shape=grid.val.shape[2:],
+                                         val_shape=val.shape[2:],
                                          depth_np=gt_depth_np)
                 mask = torch.from_numpy(mask).permute(
                     2, 1, 0).unsqueeze(0).unsqueeze(0).repeat(
-                        1, grid.val.shape[1], 1, 1, 1)
+                        1, val.shape[1], 1, 1, 1)
                 self.grid_opti_mask[key] = mask

     def get_outputs(self, input) -> Dict[str, Union[torch.Tensor, List]]:
@@ -171,12 +194,20 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
         if len(decoders_para_list) > 0:
             param_groups['decoder'] = decoders_para_list
         # grid_params
-        for key, grid in self.grid_c.items():
-            grid = grid.to(self.device)
+        for key, val in self.grid_c.items():
             if self.config.mapping_frustum_feature_selection:
+                val = val.to(self.device)
                 mask = self.grid_opti_mask[key]
-                grid.set_mask(mask)
-            param_groups[key] = list(grid.parameters())
+                val_grad = val[mask].clone()
+                val_grad = torch.nn.Parameter(val_grad.to(self.device))
+                self.masked_c_grad[key] = val_grad
+                self.masked_c_grad[key + 'mask'] = mask
+                param_groups[key] = [val_grad]
+            else:
+                val = torch.nn.Parameter(val.to(self.device))
+                self.grid_c[key] = val
+                param_groups[key] = [val]
+
         return param_groups

     # only used by mesher
@@ -239,25 +270,25 @@ def grid_init(self):
                 xyz_len=xyz_len * self.config.model_coarse_bound_enlarge,
                 grid_len=coarse_grid_len,
                 c_dim=c_dim,
-                std=0.01)
+                std=0.01).val

         middle_key = 'grid_middle'
         self.grid_c[middle_key] = FeatureGrid(xyz_len=xyz_len,
                                               grid_len=middle_grid_len,
                                               c_dim=c_dim,
-                                              std=0.01)
+                                              std=0.01).val

         fine_key = 'grid_fine'
         self.grid_c[fine_key] = FeatureGrid(xyz_len=xyz_len,
                                             grid_len=fine_grid_len,
                                             c_dim=c_dim,
-                                            std=0.0001)
+                                            std=0.0001).val

         color_key = 'grid_color'
         self.grid_c[color_key] = FeatureGrid(xyz_len=xyz_len,
                                              grid_len=color_grid_len,
                                              c_dim=c_dim,
-                                             std=0.01)
+                                             std=0.01).val

     def load_pretrain(self):
         """This function is modified from nice-slam, licensed under the Apache
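
Note: the conv_onet.py hunks replace the FeatureGrid wrapper (and its `set_mask` machinery) with plain tensors plus an explicit masked-parameter round trip: `get_param_groups` extracts the frustum-masked voxels as a leaf `nn.Parameter`, the optimizer steps only those values, and `post_processing` scatters them back into the dense grid. A self-contained sketch of that pattern with dummy shapes and a dummy loss, not the repo's code:

```python
import torch

# Stand-ins for grid_c['grid_fine'] and grid_opti_mask['grid_fine'].
grid = torch.randn(1, 32, 8, 8, 8)
mask = torch.rand_like(grid) > 0.5

# get_param_groups(): only voxels inside the frustum mask become a leaf
# Parameter, so the optimizer never touches unobserved regions.
val_grad = torch.nn.Parameter(grid[mask].clone())
opt = torch.optim.Adam([val_grad], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    loss = (val_grad - 1.0).pow(2).mean()  # dummy mapping loss
    loss.backward()
    opt.step()

# post_processing(): write the optimized values back into the full grid.
grid = grid.detach()
grid[mask] = val_grad.clone().detach()
print(grid[mask].mean())  # ~1.0; unmasked voxels are unchanged
```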