-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bcbc3a1
commit c372939
Showing
37 changed files
with
6,407 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# PyCharm | ||
.idea | ||
|
||
# MAC OS | ||
.DS_Store | ||
|
||
# pytest | ||
.coverage | ||
.pytest | ||
.pytest_cache | ||
|
||
# Python | ||
*__pycache__* | ||
*.pth | ||
|
||
# Redundant files | ||
.nfs* | ||
|
||
# Log files | ||
log | ||
|
||
# Trash | ||
*.nfs* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Copyright (c) 2021 Qualcomm Technologies, Inc. | ||
|
||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer: | ||
|
||
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. | ||
|
||
* Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. | ||
|
||
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,88 @@ | ||
# InverseForm: A Loss Function for Structured Boundary-Aware Segmentation | ||
Paper: [arXiv](https://arxiv.org/abs/2104.02745) | ||
|
||
The codebase will be available here soon. | ||
# InverseForm | ||
|
||
This repository provides the InverseForm module. | ||
|
||
Shubhankar Borse, Ying Wang, Yizhe Zhang, Fatih Porikli, "InverseForm: A Loss Function for Structured Boundary-Aware Segmentation | ||
", CVPR 2021.[[arxiv]](https://arxiv.org/abs/2104.02745) | ||
|
||
Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc) | ||
|
||
## Reference | ||
If you find our work useful for your research, please cite: | ||
```latex | ||
@inproceedings{borse2021inverseform, | ||
title={InverseForm: A Loss Function for Structured Boundary-Aware Segmentation}, | ||
author={Borse, Shubhankar and Wang, Ying and Zhang, Yizhe and Porikli, Fatih | ||
}, | ||
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, | ||
year={2021} | ||
} | ||
``` | ||
|
||
## Method | ||
InverseForm is a novel boundary-aware loss term for semantic segmentation, which efficiently learns the degree of parametric transformations between estimated and target boundaries. | ||
|
||
 | ||
|
||
This plug-in loss term complements the cross-entropy loss in capturing boundary transformations and allows consistent and significant performance improvement on segmentation backbone models without increasing their size and computational complexity. | ||
|
||
Here is an example demo from our state-of-the-art model trained on the Cityscapes benchmark. | ||
|
||
<img src="display/if_photos_gif.gif " width="425"/> <img src="display/if_labels_gif.gif " width="425"/> | ||
|
||
This repository contains the implementation of InverseForm module presented in the paper. It can also run inference on Cityscapes validation set with models trained using the InverseForm framework. The same models can be validated by removing the InverseForm framework such that no additional compute is added during inference. Here are some of the models over which you can run inference with and without the InverseForm block (right-most column of the table below): | ||
|
||
|
||
|
||
| Model | mIoU (trained w/o InverseForm) | mIoU (trained w/ InverseForm) | | ||
| :-------------: | :-----------------------------: | :-----------------------------: | | ||
| HRNet-18 | 77.0% | 77.6% | | ||
| OCRNet-48 | 86.0% | 86.3% | | ||
| OCRNet-48-HMS | 86.7% | 87.0% | | ||
|
||
|
||
## Setup environment | ||
|
||
Code has been tested with pytorch 1.3 and NVIDIA Apex. The Dockerfile is available under docker/ folder. | ||
|
||
## Cityscapes path | ||
|
||
utils/config.py has the dataset/directory information. Please update CITYSCAPES_DIR as the preferred Cityscapes directory. You can download this dataset from https://www.cityscapes-dataset.com/. | ||
|
||
## Inference on cityscapes | ||
|
||
To run inference, this directory path needs to be added to your pythonpath. Here is the command for this: | ||
|
||
```bash | ||
export PYTHONPATH="${PYTHONPATH}:/path/to/this/dir" | ||
``` | ||
|
||
Here are code snippets to run inference on the models shown above. These examples show usage with 8 GPUs. You could run the inference command with 1/2/4 GPUs by updating the nproc_per_node argument. | ||
|
||
*Checkpoints coming soon!* | ||
|
||
* HRNet-18-IF | ||
```bash | ||
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path "checkpoints/hrnet18_IF_checkpoint.pth" --has_edge True | ||
``` | ||
* OCRNet-48-IF | ||
```bash | ||
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path checkpoints/hrnet48_OCR_IF_checkpoint.pth --arch "ocrnet.HRNet" --hrnet_base "48" --has_edge True | ||
``` | ||
* HMS-OCRNet-48-IF | ||
```bash | ||
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path checkpoints/hrnet48_OCR_HMS_IF_checkpoint.pth --arch "ocrnet.HRNet_Mscale" --hrnet_base "48" --has_edge True | ||
``` | ||
|
||
To remove the InverseForm operation during inference, simply run without the has_edge flag. You will notice no drop in performance as compared to running with the operation. | ||
|
||
## Acknowledgements: | ||
|
||
This repository shares code with the following repositories: | ||
|
||
* Hierarchical Multi-Scale Attention(HMS): | ||
https://github.com/NVIDIA/semantic-segmentation | ||
* HRNet-OCR: https://github.com/HRNet/HRNet-Semantic-Segmentation | ||
|
||
We would like to acknowledge the researchers who made these repositories open-source. | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
FROM nvcr.io/nvidia/pytorch:19.10-py3 | ||
|
||
RUN pip install numpy | ||
RUN pip install runx==0.0.6 | ||
RUN pip install sklearn | ||
RUN pip install h5py | ||
RUN pip install jupyter | ||
RUN pip install scikit-image | ||
RUN pip install pillow | ||
RUN pip install piexif | ||
RUN pip install cffi | ||
RUN pip install tqdm | ||
RUN pip install dominate | ||
RUN pip install opencv-python | ||
RUN pip install nose | ||
RUN pip install ninja | ||
RUN pip install fire | ||
|
||
RUN apt-get update | ||
RUN apt-get install libgtk2.0-dev -y && rm -rf /var/lib/apt/lists/* | ||
|
||
# Install Apex | ||
RUN cd /home/ && git clone https://github.com/NVIDIA/apex.git apex && cd apex && python setup.py install --cuda_ext --cpp_ext | ||
WORKDIR /home/ | ||
|
||
RUN apt-get update \ | ||
&& apt-get install -y wget curl sudo software-properties-common | ||
|
||
# Add sudo support | ||
RUN echo "%users ALL = (ALL) NOPASSWD: ALL" >> /etc/sudoers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Copyright (c) 2021 Qualcomm Technologies, Inc. | ||
|
||
# All Rights Reserved. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from apex import amp | ||
from runx.logx import logx | ||
import numpy as np | ||
import torch | ||
import argparse | ||
import os | ||
import sys | ||
import time | ||
import fire | ||
from utils.config import assert_and_infer_cfg, cfg | ||
from utils.misc import AverageMeter, eval_metrics | ||
from utils.misc import ImageDumper | ||
from utils.trnval_utils import eval_minibatch | ||
from utils.progress_bar import printProgressBar | ||
from models.loss.utils import get_loss | ||
from models.model_loader import load_model | ||
from library.datasets.get_dataloaders import return_dataloader | ||
import models | ||
import warnings | ||
|
||
if not sys.warnoptions: | ||
warnings.simplefilter("ignore") | ||
|
||
torch.backends.cudnn.benchmark = True | ||
|
||
|
||
def set_apex_params(local_rank): | ||
""" | ||
Setting distributed parameters for Apex | ||
""" | ||
if 'WORLD_SIZE' in os.environ: | ||
world_size = int(os.environ['WORLD_SIZE']) | ||
global_rank = int(os.environ['RANK']) | ||
|
||
print('GPU {} has Rank {}'.format( | ||
local_rank, global_rank)) | ||
torch.cuda.set_device(local_rank) | ||
torch.distributed.init_process_group(backend='nccl', | ||
init_method='env://') | ||
return world_size, global_rank | ||
|
||
|
||
def inference(val_loader, net, arch, loss_fn, epoch, calc_metrics=True): | ||
""" | ||
Inference over dataloader on network | ||
""" | ||
|
||
len_dataset = len(val_loader) | ||
net.eval() | ||
val_loss = AverageMeter() | ||
iou_acc = 0 | ||
|
||
for val_idx, data in enumerate(val_loader): | ||
input_images, labels, edge, img_names, _ = data | ||
|
||
# Run network | ||
assets, _iou_acc = \ | ||
eval_minibatch(data, net, loss_fn, val_loss, calc_metrics, | ||
val_idx) | ||
iou_acc += _iou_acc | ||
if val_idx+1 < len_dataset: | ||
printProgressBar(val_idx + 1, len_dataset, 'Progress') | ||
|
||
logx.msg("\n") | ||
if calc_metrics: | ||
eval_metrics(iou_acc, net, val_loss, epoch, arch) | ||
|
||
|
||
def main(output_dir, model_path, has_edge=False, model_summary=False, arch='ocrnet.AuxHRNet', | ||
hrnet_base='18', num_workers=4, split='val', batch_size=2, crop_size='1024,2048', | ||
apex=True, syncbn=True, fp16=True, local_rank=0): | ||
|
||
#Distributed processing | ||
if apex: | ||
world_size, global_rank = set_apex_params(local_rank) | ||
else: | ||
world_size = 1 | ||
global_rank = 0 | ||
local_rank = 0 | ||
|
||
#Logging | ||
logx.initialize(logdir=output_dir, | ||
tensorboard=True, | ||
global_rank=global_rank) | ||
|
||
#Build config | ||
assert_and_infer_cfg(output_dir, global_rank, apex, syncbn, arch, hrnet_base, | ||
fp16, has_edge) | ||
|
||
#Dataloader | ||
val_loader = return_dataloader(num_workers, batch_size) | ||
|
||
#Loss function | ||
loss_fn = get_loss(has_edge) | ||
|
||
assert model_path is not None, 'need pytorch model for inference' | ||
|
||
#Load Network | ||
checkpoint = torch.load(model_path, map_location=torch.device('cpu')) | ||
logx.msg("Loading weights from: {}".format(model_path)) | ||
net = models.get_net(arch, loss_fn) | ||
if fp16: | ||
net = amp.initialize(net, opt_level='O1', verbosity=0) | ||
net = models.wrap_network_in_dataparallel(net, apex) | ||
#restore_net(net, checkpoint, arch) | ||
load_model(net, checkpoint) | ||
#Summary of MAC/#param | ||
if model_summary: | ||
from thop import profile | ||
img = torch.randn(1, 3, 1024, 2048).cuda() | ||
mask = torch.randn(1, 1, 1024, 2048).cuda() | ||
macs, params = profile(net, inputs=({'images': img, 'gts': mask}, )) | ||
print(f'macs {macs} params {params}') | ||
sys.exit() | ||
|
||
|
||
torch.cuda.empty_cache() | ||
|
||
#Run inference | ||
inference(val_loader, net, arch, loss_fn, epoch=0) | ||
|
||
|
||
if __name__ == '__main__': | ||
fire.Fire(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
import os.path as path | ||
from utils.config import cfg | ||
import library.data.cityscapes_labels as cityscapes_labels | ||
|
||
|
||
def find_directories(root): | ||
""" | ||
Find folders in validation set. | ||
""" | ||
trn_path = path.join(root, 'leftImg8bit', 'train') | ||
val_path = path.join(root, 'leftImg8bit', 'val') | ||
|
||
trn_directories = ['train/' + c for c in os.listdir(trn_path)] | ||
trn_directories = sorted(trn_directories) # sort to insure reproducibility | ||
val_directories = ['val/' + c for c in os.listdir(val_path)] | ||
|
||
return val_directories |
Oops, something went wrong.