Skip to content

Commit

Permalink
Code available.
Browse files Browse the repository at this point in the history
  • Loading branch information
mhofmann-qc committed Aug 2, 2021
1 parent bcbc3a1 commit c372939
Show file tree
Hide file tree
Showing 37 changed files with 6,407 additions and 4 deletions.
23 changes: 23 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# PyCharm
.idea

# MAC OS
.DS_Store

# pytest
.coverage
.pytest
.pytest_cache

# Python
*__pycache__*
*.pth

# Redundant files
.nfs*

# Log files
log

# Trash
*.nfs*
13 changes: 13 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Copyright (c) 2021 Qualcomm Technologies, Inc.

All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer:

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
92 changes: 88 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,88 @@
# InverseForm: A Loss Function for Structured Boundary-Aware Segmentation
Paper: [arXiv](https://arxiv.org/abs/2104.02745)

The codebase will be available here soon.
# InverseForm

This repository provides the InverseForm module.

Shubhankar Borse, Ying Wang, Yizhe Zhang, Fatih Porikli, "InverseForm: A Loss Function for Structured Boundary-Aware Segmentation
", CVPR 2021.[[arxiv]](https://arxiv.org/abs/2104.02745)

Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc)

## Reference
If you find our work useful for your research, please cite:
```latex
@inproceedings{borse2021inverseform,
title={InverseForm: A Loss Function for Structured Boundary-Aware Segmentation},
author={Borse, Shubhankar and Wang, Ying and Zhang, Yizhe and Porikli, Fatih
},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2021}
}
```

## Method
InverseForm is a novel boundary-aware loss term for semantic segmentation, which efficiently learns the degree of parametric transformations between estimated and target boundaries.

![! an image](display/inverseform_framework.png)

This plug-in loss term complements the cross-entropy loss in capturing boundary transformations and allows consistent and significant performance improvement on segmentation backbone models without increasing their size and computational complexity.

Here is an example demo from our state-of-the-art model trained on the Cityscapes benchmark.

<img src="display/if_photos_gif.gif " width="425"/> <img src="display/if_labels_gif.gif " width="425"/>

This repository contains the implementation of InverseForm module presented in the paper. It can also run inference on Cityscapes validation set with models trained using the InverseForm framework. The same models can be validated by removing the InverseForm framework such that no additional compute is added during inference. Here are some of the models over which you can run inference with and without the InverseForm block (right-most column of the table below):



| Model | mIoU (trained w/o InverseForm) | mIoU (trained w/ InverseForm) |
| :-------------: | :-----------------------------: | :-----------------------------: |
| HRNet-18 | 77.0% | 77.6% |
| OCRNet-48 | 86.0% | 86.3% |
| OCRNet-48-HMS | 86.7% | 87.0% |


## Setup environment

Code has been tested with pytorch 1.3 and NVIDIA Apex. The Dockerfile is available under docker/ folder.

## Cityscapes path

utils/config.py has the dataset/directory information. Please update CITYSCAPES_DIR as the preferred Cityscapes directory. You can download this dataset from https://www.cityscapes-dataset.com/.

## Inference on cityscapes

To run inference, this directory path needs to be added to your pythonpath. Here is the command for this:

```bash
export PYTHONPATH="${PYTHONPATH}:/path/to/this/dir"
```

Here are code snippets to run inference on the models shown above. These examples show usage with 8 GPUs. You could run the inference command with 1/2/4 GPUs by updating the nproc_per_node argument.

*Checkpoints coming soon!*

* HRNet-18-IF
```bash
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path "checkpoints/hrnet18_IF_checkpoint.pth" --has_edge True
```
* OCRNet-48-IF
```bash
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path checkpoints/hrnet48_OCR_IF_checkpoint.pth --arch "ocrnet.HRNet" --hrnet_base "48" --has_edge True
```
* HMS-OCRNet-48-IF
```bash
python -m torch.distributed.launch --nproc_per_node=8 experiment/validation.py --output_dir "/path/to/output/dir" --model_path checkpoints/hrnet48_OCR_HMS_IF_checkpoint.pth --arch "ocrnet.HRNet_Mscale" --hrnet_base "48" --has_edge True
```

To remove the InverseForm operation during inference, simply run without the has_edge flag. You will notice no drop in performance as compared to running with the operation.

## Acknowledgements:

This repository shares code with the following repositories:

* Hierarchical Multi-Scale Attention(HMS):
https://github.com/NVIDIA/semantic-segmentation
* HRNet-OCR: https://github.com/HRNet/HRNet-Semantic-Segmentation

We would like to acknowledge the researchers who made these repositories open-source.

Binary file added display/if_labels_gif.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added display/if_photos_gif.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added display/inverseform_framework.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 30 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
FROM nvcr.io/nvidia/pytorch:19.10-py3

RUN pip install numpy
RUN pip install runx==0.0.6
RUN pip install sklearn
RUN pip install h5py
RUN pip install jupyter
RUN pip install scikit-image
RUN pip install pillow
RUN pip install piexif
RUN pip install cffi
RUN pip install tqdm
RUN pip install dominate
RUN pip install opencv-python
RUN pip install nose
RUN pip install ninja
RUN pip install fire

RUN apt-get update
RUN apt-get install libgtk2.0-dev -y && rm -rf /var/lib/apt/lists/*

# Install Apex
RUN cd /home/ && git clone https://github.com/NVIDIA/apex.git apex && cd apex && python setup.py install --cuda_ext --cpp_ext
WORKDIR /home/

RUN apt-get update \
&& apt-get install -y wget curl sudo software-properties-common

# Add sudo support
RUN echo "%users ALL = (ALL) NOPASSWD: ALL" >> /etc/sudoers
130 changes: 130 additions & 0 deletions experiment/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

from __future__ import absolute_import
from __future__ import division
from apex import amp
from runx.logx import logx
import numpy as np
import torch
import argparse
import os
import sys
import time
import fire
from utils.config import assert_and_infer_cfg, cfg
from utils.misc import AverageMeter, eval_metrics
from utils.misc import ImageDumper
from utils.trnval_utils import eval_minibatch
from utils.progress_bar import printProgressBar
from models.loss.utils import get_loss
from models.model_loader import load_model
from library.datasets.get_dataloaders import return_dataloader
import models
import warnings

if not sys.warnoptions:
warnings.simplefilter("ignore")

torch.backends.cudnn.benchmark = True


def set_apex_params(local_rank):
"""
Setting distributed parameters for Apex
"""
if 'WORLD_SIZE' in os.environ:
world_size = int(os.environ['WORLD_SIZE'])
global_rank = int(os.environ['RANK'])

print('GPU {} has Rank {}'.format(
local_rank, global_rank))
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
return world_size, global_rank


def inference(val_loader, net, arch, loss_fn, epoch, calc_metrics=True):
"""
Inference over dataloader on network
"""

len_dataset = len(val_loader)
net.eval()
val_loss = AverageMeter()
iou_acc = 0

for val_idx, data in enumerate(val_loader):
input_images, labels, edge, img_names, _ = data

# Run network
assets, _iou_acc = \
eval_minibatch(data, net, loss_fn, val_loss, calc_metrics,
val_idx)
iou_acc += _iou_acc
if val_idx+1 < len_dataset:
printProgressBar(val_idx + 1, len_dataset, 'Progress')

logx.msg("\n")
if calc_metrics:
eval_metrics(iou_acc, net, val_loss, epoch, arch)


def main(output_dir, model_path, has_edge=False, model_summary=False, arch='ocrnet.AuxHRNet',
hrnet_base='18', num_workers=4, split='val', batch_size=2, crop_size='1024,2048',
apex=True, syncbn=True, fp16=True, local_rank=0):

#Distributed processing
if apex:
world_size, global_rank = set_apex_params(local_rank)
else:
world_size = 1
global_rank = 0
local_rank = 0

#Logging
logx.initialize(logdir=output_dir,
tensorboard=True,
global_rank=global_rank)

#Build config
assert_and_infer_cfg(output_dir, global_rank, apex, syncbn, arch, hrnet_base,
fp16, has_edge)

#Dataloader
val_loader = return_dataloader(num_workers, batch_size)

#Loss function
loss_fn = get_loss(has_edge)

assert model_path is not None, 'need pytorch model for inference'

#Load Network
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
logx.msg("Loading weights from: {}".format(model_path))
net = models.get_net(arch, loss_fn)
if fp16:
net = amp.initialize(net, opt_level='O1', verbosity=0)
net = models.wrap_network_in_dataparallel(net, apex)
#restore_net(net, checkpoint, arch)
load_model(net, checkpoint)
#Summary of MAC/#param
if model_summary:
from thop import profile
img = torch.randn(1, 3, 1024, 2048).cuda()
mask = torch.randn(1, 1, 1024, 2048).cuda()
macs, params = profile(net, inputs=({'images': img, 'gts': mask}, ))
print(f'macs {macs} params {params}')
sys.exit()


torch.cuda.empty_cache()

#Run inference
inference(val_loader, net, arch, loss_fn, epoch=0)


if __name__ == '__main__':
fire.Fire(main)
18 changes: 18 additions & 0 deletions library/data/cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import os.path as path
from utils.config import cfg
import library.data.cityscapes_labels as cityscapes_labels


def find_directories(root):
"""
Find folders in validation set.
"""
trn_path = path.join(root, 'leftImg8bit', 'train')
val_path = path.join(root, 'leftImg8bit', 'val')

trn_directories = ['train/' + c for c in os.listdir(trn_path)]
trn_directories = sorted(trn_directories) # sort to insure reproducibility
val_directories = ['val/' + c for c in os.listdir(val_path)]

return val_directories
Loading

0 comments on commit c372939

Please sign in to comment.