Skip to content

Commit

Permalink
Update fix structor to contain multiple annotation changes (#242)
Browse files Browse the repository at this point in the history
This PR introduces a necessary change in preparation for the upcoming major PR that will add support for taint checker inference.

The key change in this PR is the update to the `Fix` structure, which now contains a list of changes rather than a single change. With this new design, a `Fix` instance represents the smallest unit of annotation changes that the Annotator will evaluate for its impact. If approved, these changes will be applied to the source code. It’s important to note that resolving an error may require adding multiple annotations, and a `Fix` instance can include a subset or all of these changes.

In NullAway inference, we evaluate each annotation individually and apply it if approved. For example, resolving an initialization error may require annotating multiple fields as `@Nullable`, but each annotation is evaluated and applied one at a time. In contrast, taint inference requires evaluating all suggested annotations together for a given error. 

This updated design lays the groundwork for supporting the taint checker inference by accommodating the evaluation of grouped annotations for each error.
  • Loading branch information
nimakarimipour authored Oct 8, 2024
1 parent a60d7d2 commit 5e3437e
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private void executeNextIteration(
injector.injectFixes(selectedFixes);
// Update log.
context.log.updateInjectedAnnotations(
selectedFixes.stream().map(fix -> fix.change).collect(Collectors.toSet()));
selectedFixes.stream().flatMap(fix -> fix.changes.stream()).collect(Collectors.toSet()));
// Update impact saved state.
downstreamImpactCache.updateImpactsAfterInjection(selectedFixes);
targetModuleCache.updateImpactsAfterInjection(selectedFixes);
Expand Down
15 changes: 4 additions & 11 deletions annotator-core/src/main/java/edu/ucr/cs/riple/core/Report.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import edu.ucr.cs.riple.core.module.ModuleInfo;
import edu.ucr.cs.riple.core.registries.index.Error;
import edu.ucr.cs.riple.core.registries.index.Fix;
import edu.ucr.cs.riple.injector.location.Location;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -180,22 +179,16 @@ public boolean testEquals(Config config, Report found) {
}
this.tree.add(this.root);
found.tree.add(found.root);
Set<Location> thisTree = this.tree.stream().map(Fix::toLocation).collect(Collectors.toSet());
Set<Location> otherTree = found.tree.stream().map(Fix::toLocation).collect(Collectors.toSet());
if (!thisTree.equals(otherTree)) {
if (!this.tree.equals(found.tree)) {
return false;
}
Set<Location> thisTriggered =
Set<Fix> thisTriggered =
this.triggeredErrors.stream()
.filter(Error::hasFix)
.flatMap(error -> error.getResolvingFixes().stream())
.map(Fix::toLocation)
.collect(Collectors.toSet());
Set<Location> otherTriggered =
Set<Fix> otherTriggered =
found.triggeredErrors.stream()
.filter(Error::hasFix)
.flatMap(error -> error.getResolvingFixes().stream())
.map(Fix::toLocation)
.collect(Collectors.toSet());
return otherTriggered.equals(thisTriggered);
}
Expand All @@ -207,7 +200,7 @@ public String toString() {
+ ", "
+ root
+ ", "
+ tree.stream().map(Fix::toLocation).collect(Collectors.toSet());
+ tree.stream().map(Fix::toLocations).collect(Collectors.toSet());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.common.collect.ImmutableSet;
import edu.ucr.cs.riple.core.registries.index.Error;
import edu.ucr.cs.riple.core.registries.index.Fix;
import edu.ucr.cs.riple.injector.location.Location;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
Expand All @@ -39,8 +38,7 @@
* @param <T> type of impacts saved in this model.
* @param <S> type of the map used to store impacts.
*/
public abstract class BaseCache<T extends Impact, S extends Map<Location, T>>
implements ImpactCache<T> {
public abstract class BaseCache<T extends Impact, S extends Map<Fix, T>> implements ImpactCache<T> {

/** Container holding cache entries. */
protected final S store;
Expand All @@ -51,19 +49,19 @@ public BaseCache(S store) {

@Override
public boolean isUnknown(Fix fix) {
return !this.store.containsKey(fix.toLocation());
return !this.store.containsKey(fix);
}

@Nullable
@Override
public T fetchImpact(Fix fix) {
return store.get(fix.toLocation());
return store.get(fix);
}

@Override
public ImmutableSet<Error> getTriggeredErrorsForCollection(Collection<Fix> fixes) {
return fixes.stream()
.map(fix -> store.get(fix.toLocation()))
.map(store::get)
.filter(Objects::nonNull)
.flatMap(impact -> impact.triggeredErrors.stream())
// filter errors that will be resolved with the existing collection of fixes.
Expand All @@ -74,7 +72,7 @@ public ImmutableSet<Error> getTriggeredErrorsForCollection(Collection<Fix> fixes
@Override
public ImmutableSet<Fix> getTriggeredFixesFromDownstreamForCollection(Collection<Fix> fixTree) {
return fixTree.stream()
.map(fix -> store.get(fix.toLocation()))
.map(store::get)
.filter(Objects::nonNull)
.flatMap(impact -> impact.getTriggeredFixesFromDownstreamErrors().stream())
// filter fixes that are already inside tree.
Expand All @@ -95,6 +93,6 @@ public void updateImpactsAfterInjection(Collection<Fix> fixes) {

@Override
public int size() {
return this.store.values().size();
return this.store.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.common.collect.ImmutableSet;
import edu.ucr.cs.riple.core.registries.index.Error;
import edu.ucr.cs.riple.core.registries.index.Fix;
import edu.ucr.cs.riple.injector.location.Location;
import java.util.Collection;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -90,15 +89,6 @@ public ImmutableSet<Fix> getTriggeredFixesFromDownstreamErrors() {
return triggeredFixesFromDownstreamErrors;
}

/**
* Gets the containing location.
*
* @return Containing fix location.
*/
public Location toLocation() {
return fix.toLocation();
}

@Override
public int hashCode() {
return fix.hashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@

package edu.ucr.cs.riple.core.cache;

import edu.ucr.cs.riple.injector.location.Location;
import edu.ucr.cs.riple.core.registries.index.Fix;
import java.util.HashMap;
import java.util.Set;

/**
* Cache for storing impacts of fixes on target module. This cache's state is not immutable and can
* be updated.
*/
public class TargetModuleCache extends BaseCache<Impact, HashMap<Location, Impact>> {
public class TargetModuleCache extends BaseCache<Impact, HashMap<Fix, Impact>> {

public TargetModuleCache() {
super(new HashMap<>());
Expand All @@ -44,6 +44,6 @@ public TargetModuleCache() {
* @param newData New given impacts.
*/
public void updateCacheState(Set<Impact> newData) {
newData.forEach(t -> store.put(t.toLocation(), t));
newData.forEach(t -> store.put(t.fix, t));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
* once created, cannot be updated.
*/
public class DownstreamImpactCacheImpl
extends BaseCache<DownstreamImpact, Map<Location, DownstreamImpact>>
extends BaseCache<DownstreamImpact, Map<Fix, DownstreamImpact>>
implements DownstreamImpactCache {

/** Annotator context instance. */
Expand Down Expand Up @@ -130,7 +130,7 @@ public void analyzeDownstreamDependencies() {
reports.forEach(
report -> {
DownstreamImpact impact = new DownstreamImpact(report);
store.put(report.root.toLocation(), impact);
store.put(report.root, impact);
});
System.out.println("Analyzing downstream dependencies completed!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ protected void collectGraphResults(ImmutableSet<Report> reports) {
node.triggeredErrors.stream()
.filter(
error ->
error.isSingleFix()
error.isSingleAnnotationFix()
// Method is declared in the target module.
&& context.targetModuleInfo.declaredInModule(
error.toResolvingLocation()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,30 @@ private NullAwayError deserializeErrorFromTSVLine(ModuleInfo moduleInfo, String
Location nonnullTarget =
Location.createLocationFromArrayInfo(Arrays.copyOfRange(values, 6, 12));
if (nonnullTarget == null && errorType.equals(NullAwayError.METHOD_INITIALIZER_ERROR)) {
ImmutableSet<Fix> resolvingFixes =
generateFixesForUninitializedFields(errorMessage, region, moduleInfo);
Set<AddAnnotation> annotationsOnField =
computeAddAnnotationInstancesForUninitializedFields(
errorMessage, region.clazz, moduleInfo);
return createError(
errorType,
errorMessage,
region,
context.offsetHandler.getOriginalOffset(path, offset),
resolvingFixes,
annotationsOnField,
moduleInfo);
}
if (nonnullTarget != null && nonnullTarget.isOnField()) {
nonnullTarget = extendVariableList(nonnullTarget.toField(), moduleInfo);
}
Fix resolvingFix =
Set<AddAnnotation> annotations =
nonnullTarget == null
? null
: new Fix(new AddMarkerAnnotation(nonnullTarget, config.nullableAnnot));
? Set.of()
: Set.of(new AddMarkerAnnotation(nonnullTarget, config.nullableAnnot));
return createError(
errorType,
errorMessage,
region,
context.offsetHandler.getOriginalOffset(path, offset),
resolvingFix == null ? ImmutableSet.of() : ImmutableSet.of(resolvingFix),
annotations,
moduleInfo);
}

Expand Down Expand Up @@ -181,25 +182,31 @@ private Set<String> extractUninitializedFieldNames(String errorMessage) {
}

/**
* Generates a set of fixes for uninitialized fields from the given error message.
* Computes a set of {@link AddAnnotation} instances for fields that are uninitialized. This
* method extracts field names from the provided error message, and for each uninitialized field,
* it attempts to find the location of the field within the specified class. If a field's location
* is found, an {@link AddMarkerAnnotation} is created with the appropriate nullable annotation
* and added to the result set.
*
* @param errorMessage Given error message.
* @param region Region where the error is reported.
* @return Set of fixes for uninitialized fields to resolve the given error.
* @param errorMessage the error message containing the details about uninitialized fields.
* @param encClass The class where this error is reported.
* @param module the {@link ModuleInfo} containing the field registry and configuration
* information.
* @return an {@link ImmutableSet} of {@link AddAnnotation} instances representing the fields that
* should have annotations added, based on their uninitialized status.
*/
protected ImmutableSet<Fix> generateFixesForUninitializedFields(
String errorMessage, Region region, ModuleInfo module) {
private ImmutableSet<AddAnnotation> computeAddAnnotationInstancesForUninitializedFields(
String errorMessage, String encClass, ModuleInfo module) {
return extractUninitializedFieldNames(errorMessage).stream()
.map(
field -> {
OnField locationOnField =
module.getFieldRegistry().getLocationOnField(region.clazz, field);
module.getFieldRegistry().getLocationOnField(encClass, field);
if (locationOnField == null) {
return null;
}
return new Fix(
new AddMarkerAnnotation(
extendVariableList(locationOnField, module), config.nullableAnnot));
return new AddMarkerAnnotation(
extendVariableList(locationOnField, module), config.nullableAnnot);
})
.filter(Objects::nonNull)
.collect(ImmutableSet.toImmutableSet());
Expand Down Expand Up @@ -249,7 +256,7 @@ public void suppressRemainingErrors(AnnotationInjector injector) {
// `@Nullable` is being passed as an argument, we add a `@NullUnmarked` annotation
// to the called method.
if (error.messageType.equals("PASS_NULLABLE")
&& error.isSingleFix()
&& error.isSingleAnnotationFix()
&& error.toResolvingLocation().isOnParameter()) {
OnParameter nullableParameter = error.toResolvingParameter();
return context
Expand Down Expand Up @@ -382,7 +389,7 @@ public void preprocess(AnnotationInjector injector) {
* @param errorMessage Error Message from NullAway.
* @param region Region where the error is reported.
* @param offset Offset of program point in the source file where the error is reported.
* @param resolvingFixes Resolving fixes that can fix the error if all applied to source code.
* @param annotations Annotations that should be added source file to resolve the error.
* @param module Module where this error is reported.
* @return Creates and returns the corresponding {@link NullAwayError} instance using the provided
* information.
Expand All @@ -392,14 +399,16 @@ private NullAwayError createError(
String errorMessage,
Region region,
int offset,
ImmutableSet<Fix> resolvingFixes,
Set<AddAnnotation> annotations,
ModuleInfo module) {
// Filter fixes on elements with explicit nonnull annotations.
ImmutableSet<Fix> cleanedResolvingFixes =
resolvingFixes.stream()
.filter(f -> !module.getNonnullStore().hasExplicitNonnullAnnotation(f.toLocation()))
ImmutableSet<AddAnnotation> cleanedAnnotations =
annotations.stream()
.filter(
annot ->
!module.getNonnullStore().hasExplicitNonnullAnnotation(annot.getLocation()))
.collect(ImmutableSet.toImmutableSet());
return new NullAwayError(errorType, errorMessage, region, offset, cleanedResolvingFixes);
return new NullAwayError(errorType, errorMessage, region, offset, cleanedAnnotations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import edu.ucr.cs.riple.core.registries.index.Error;
import edu.ucr.cs.riple.core.registries.index.Fix;
import edu.ucr.cs.riple.core.registries.region.Region;
import edu.ucr.cs.riple.injector.changes.AddAnnotation;
import java.util.Objects;
import java.util.Set;

/** Represents an error reported by {@link NullAway}. */
public class NullAwayError extends Error {
Expand All @@ -43,8 +45,15 @@ public NullAwayError(
String message,
Region region,
int offset,
ImmutableSet<Fix> resolvingFixes) {
super(messageType, message, region, offset, resolvingFixes);
Set<AddAnnotation> annotations) {
super(messageType, message, region, offset, annotations);
}

@Override
protected ImmutableSet<Fix> computeFixesFromAnnotations(Set<AddAnnotation> annotations) {
// In NullAway inference, each annotation is examined individually. Thus, we create a separate
// fix instance for each annotation.
return annotations.stream().map(Fix::new).collect(ImmutableSet.toImmutableSet());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Vertex in {@link ConflictGraph} graph. It stores a fix tree (starting from a root) and all it's
Expand Down Expand Up @@ -109,11 +110,15 @@ public void reCollectPotentiallyImpactedRegions(RegionRegistry regionRegistry) {
// Add origins.
this.regions.addAll(this.origins);
this.tree.forEach(
fix -> this.regions.addAll(regionRegistry.getImpactedRegions(fix.toLocation())));
fix ->
this.regions.addAll(
fix.toLocations().stream()
.flatMap(location -> regionRegistry.getImpactedRegions(location).stream())
.collect(Collectors.toSet())));
// Add class initialization region, if a fix is modifying a parameter on constructor.
this.tree.stream()
.filter(fix -> fix.isOnParameter() && fix.isModifyingConstructor())
.forEach(fix -> regions.add(new Region(fix.change.getLocation().clazz, "null")));
.forEach(fix -> regions.add(new Region(fix.toParameter().clazz, "null")));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ public AbstractConflictGraphProcessor(Context context, CompilerRunner runner, Su
*/
protected Set<Fix> getTriggeredFixesFromDownstreamErrors(Node node) {
Set<Location> currentLocationsTargetedByTree =
node.tree.stream().map(Fix::toLocation).collect(Collectors.toSet());
node.tree.stream().flatMap(fix -> fix.toLocations().stream()).collect(Collectors.toSet());
return downstreamImpactCache.getTriggeredErrorsForCollection(node.tree).stream()
.filter(
error ->
error.isSingleFix()
error.isSingleAnnotationFix()
&& error.isFixableOnTarget(context)
&& !currentLocationsTargetedByTree.contains(error.toResolvingLocation()))
.map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ public void removeFixes(Set<Fix> fixes) {
return;
}
Set<RemoveAnnotation> toRemove =
fixes.stream().map(fix -> fix.change.getReverse()).collect(Collectors.toSet());
fixes.stream()
.flatMap(fix -> fix.changes.stream().map(AddAnnotation::getReverse))
.collect(Collectors.toSet());
removeAnnotations(toRemove);
}

Expand All @@ -63,7 +65,8 @@ public void injectFixes(Set<Fix> fixes) {
if (fixes == null || fixes.size() == 0) {
return;
}
injectAnnotations(fixes.stream().map(fix -> fix.change).collect(Collectors.toSet()));
injectAnnotations(
fixes.stream().flatMap(fix -> fix.changes.stream()).collect(Collectors.toSet()));
}

/**
Expand Down
Loading

0 comments on commit 5e3437e

Please sign in to comment.