Skip to content

Commit

Permalink
test: label propagation with DistanceTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
bogovicj committed Feb 27, 2025
1 parent 19cb4ec commit 1a99586
Showing 1 changed file with 186 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@

package net.imglib2.algorithm.morphology.distance;

import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -46,6 +52,7 @@
import org.junit.Test;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.Localizable;
import net.imglib2.Point;
import net.imglib2.RandomAccess;
Expand All @@ -58,16 +65,20 @@
import net.imglib2.img.basictypeaccess.array.DoubleArray;
import net.imglib2.img.basictypeaccess.array.LongArray;
import net.imglib2.type.logic.BitType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.util.Intervals;
import net.imglib2.util.Localizables;
import net.imglib2.util.Pair;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

/**
*
* @author Philipp Hanslovsky
*
* @author John Bogovic
*/
public class DistanceTransformTest
{
Expand Down Expand Up @@ -159,9 +170,7 @@ private void testBinary( final DISTANCE_TYPE dt, final DistanceCalculator distan

private static void compareRAIofRealType( final RandomAccessibleInterval< ? extends RealType< ? > > ref, final RandomAccessibleInterval< ? extends RealType< ? > > comp, final double tolerance )
{
Assert.assertArrayEquals( Intervals.dimensionsAsLongArray( ref ), Intervals.dimensionsAsLongArray( comp ) );
Assert.assertArrayEquals( Intervals.minAsLongArray( ref ), Intervals.minAsLongArray( comp ) );
Assert.assertArrayEquals( Intervals.maxAsLongArray( ref ), Intervals.maxAsLongArray( comp ) );
assertTrue( Intervals.equals( ref, comp ) );
for ( final Pair< ? extends RealType< ? >, ? extends RealType< ? > > p : Views.flatIterable( Views.interval( Views.pair( ref, comp ), ref ) ) )
{
Assert.assertEquals( p.getA().getRealDouble(), p.getB().getRealDouble(), tolerance );
Expand Down Expand Up @@ -440,12 +449,184 @@ private static < T extends RealType< T > > void checkDistance(
final double[] weights,
final DistanceCalculator distanceCalculator )
{
for ( final Cursor< T > c = Views.iterable( dist ).localizingCursor(); c.hasNext(); )
for ( final Cursor< T > c = dist.localizingCursor(); c.hasNext(); )
{
final double actual = c.next().getRealDouble();
final double expected = atSamePosition( foreground, c ) ? 0.0 : distanceCalculator.dist( foreground, c, weights );
Assert.assertEquals( expected, actual, 0.0 );
}
}

@Test
public void testLabelPropagation()
{
/*
* Iterate over numReplicates = [0..9] numDimensions = [2, 3] numLabels
* = [1..5]
*/
final int firstReplicate = 0;
final int lastReplicate = 9;

final int firstNumDimensions = 2;
final int lastNumDimensions = 3;

final int firstNumLabels = 2;
final int lastNumLabels = 5;

final RandomAccessibleInterval< Localizable > parameters = Localizables.randomAccessibleInterval(
Intervals.createMinMax(
firstReplicate, firstNumDimensions,
firstNumLabels, lastReplicate,
lastNumDimensions, lastNumLabels ) );

parameters.forEach( params -> {

@SuppressWarnings( "unused" )
final int replicate = params.getIntPosition( 0 );
final int numDimensions = params.getIntPosition( 1 );
final int numLabels = params.getIntPosition( 2 );

testLabelPropagationHelper( numDimensions, numLabels );
testLabelPropagationHelperParallel( numDimensions, numLabels );
} );

}

/**
* Creates an label and distances images with the requested number of dimensions (ndims),
* and places nLabels points with non-zero label. Checks that the propagated labels correctly
* reflect the nearest label (ties are allowed: any label equi-distant to a point passes).
*
* @param ndims number of dimensions
* @param nLabels number of labels
*/
private void testLabelPropagationHelper( int ndims, int nLabels )
{

final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray();
final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims );

final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels );
DistanceTransform.labelTransform( labels, 0 );
validateLabelsSet( "serial", points, labels );
}

/**
* Creates an label and distances images with the requested number of dimensions (ndims),
* and places nLabels points with non-zero label. Checks that the propagated labels correctly
* reflect the nearest label (ties are allowed: any label equi-distant to a point passes).
*
* @param ndims number of dimensions
* @param nLabels number of labels
*/
private void testLabelPropagationHelperParallel( int ndims, int nLabels )
{

final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray();
final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims );
final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels );
DistanceTransform.labelTransform( labels, 0, es, 3 * nThreads );
validateLabelsSet( "parallel", points, labels );
}

private ArrayImg< LongType, LongArray > copyLongArrayImg( ArrayImg< LongType, LongArray > img )
{

final long[] dataOrig = img.getAccessType().getCurrentStorageArray();
final long[] dataCopy = new long[ dataOrig.length ];
System.arraycopy( dataOrig, 0, dataCopy, 0, dataOrig.length );
return ArrayImgs.longs( dataCopy, img.dimensionsAsLongArray() );
}

private static Point randomPointInInterval( final Random rng, final Interval itvl )
{
final int[] coords = IntStream.range( 0, itvl.numDimensions() ).map( i -> {
return rng.nextInt( ( int ) itvl.dimension( i ) );
} ).toArray();
return new Point( coords );
}

private static < T extends RealType< T >, L extends IntegerType< L > > Set< PointAndLabel > initializeLabels( Random random, int numLabels, RandomAccessibleInterval< L > labels )
{
labels.forEach( p -> p.setZero() ); // Initialize all labels to 0
Set< PointAndLabel > positions = new HashSet<>();

int currentLabel = 1;
// Set numLabels different random positions to a non-zero label
while ( positions.size() < numLabels )
{
final Point pt = randomPointInInterval( random, labels );
if ( !positions.contains( pt ) )
{

final PointAndLabel candidate = new PointAndLabel( currentLabel, pt.positionAsLongArray() );
if ( !positions.contains( candidate ) )
{
positions.add( candidate );
labels.randomAccess().setPositionAndGet( pt ).setInteger( currentLabel );
currentLabel++;
}

}
}
return positions;
}

/**
* Return the set of points within epsilon distance of the query point
*
* @param query point
* @param pointSet set of candidate points
* @param epsilon distance threshold
* @return the set of close points
*/
private static List< PointAndLabel > closestSet( Localizable query, Set< PointAndLabel > pointSet, final double epsilon )
{

final List< PointAndLabel > listOfEquidistant = new ArrayList<>();

double mindist = Double.MAX_VALUE;
for ( PointAndLabel pt : pointSet )
{
double dist = Util.distance( query, pt );

if ( Math.abs( dist - mindist ) < epsilon )
{
listOfEquidistant.add( pt );
}
else if ( dist < mindist )
{
mindist = dist;
listOfEquidistant.clear();
listOfEquidistant.add( pt );
}
}

return listOfEquidistant;
}

private static < T extends RealType< T >, L extends IntegerType< L > > void validateLabelsSet( final String prefix, final Set< PointAndLabel > points, final RandomAccessibleInterval< L > labels )
{
final double EPS = 0.01;
final Cursor< L > c = labels.cursor();
while ( c.hasNext() )
{
c.fwd();
final boolean labelIsClosest = closestSet( c, points, EPS ).stream().anyMatch( p -> p.label == c.get().getIntegerLong() );
assertTrue( prefix + " point: " + Arrays.toString( c.positionAsLongArray() ), labelIsClosest );
}
}

private static class PointAndLabel extends Point
{

long label;

public PointAndLabel( long label, long[] position )
{
super( position );
this.label = label;
}
}

}

0 comments on commit 1a99586

Please sign in to comment.