diff --git a/src/pyhf/workspace.py b/src/pyhf/workspace.py index 7ce0bc486d..b63841c02c 100644 --- a/src/pyhf/workspace.py +++ b/src/pyhf/workspace.py @@ -583,16 +583,39 @@ def _prune_and_rename( ), ) for modifier in sample['modifiers'] - if modifier['name'] not in prune_modifiers - and modifier['type'] not in prune_modifier_types + # we want to remove modifiers only if channel is not in list of channels to keep, + # we want to remove modifiers only if sample is not in list of samples to keep + if ( + prune_channels + and channel['name'] not in prune_channels + ) + or ( + prune_samples + and sample['name'] not in prune_samples + ) + or ( + modifier['name'] not in prune_modifiers + and modifier['type'] not in prune_modifier_types + ) + # need to keep the modifier in case it is used in another measurement + or prune_measurements ], } for sample in channel['samples'] - if sample['name'] not in prune_samples + # we want to remove samples only if channel is not in list of channels to keep, + # we want to remove samples only if no modifiers are to be pruned + if (prune_channels and channel['name'] not in prune_channels) + or sample['name'] not in prune_samples + or prune_modifiers + or prune_modifier_types ], } for channel in self['channels'] + # we want to remove channels only if no samples or modifiers are to be pruned if channel['name'] not in prune_channels + or prune_samples + or prune_modifiers + or prune_modifier_types ], 'measurements': [ { @@ -607,8 +630,16 @@ def _prune_and_rename( parameter['name'], parameter['name'] ), ) - for parameter in measurement['config']['parameters'] - if parameter['name'] not in prune_modifiers + for parameter in measurement['config'][ + 'parameters' + ] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty + # we want to remove parameters from a measurement only + # if measurement is not in keep_measurements + if ( + prune_measurements + and measurement['name'] not in prune_measurements + ) + or parameter['name'] not in prune_modifiers ], 'poi': rename_modifiers.get( measurement['config']['poi'], measurement['config']['poi'] @@ -616,7 +647,8 @@ def _prune_and_rename( }, } for measurement in self['measurements'] - if measurement['name'] not in prune_measurements + # we want to remove measurements only if no parameters are to be pruned + if measurement['name'] not in prune_measurements or prune_modifiers ], 'observations': [ dict( @@ -624,7 +656,12 @@ def _prune_and_rename( name=rename_channels.get(observation['name'], observation['name']), ) for observation in self['observations'] + # we want to remove this channels only + # if no samples or modifiers are to be pruned if observation['name'] not in prune_channels + or prune_samples + or prune_modifiers + or prune_modifier_types ], 'version': self['version'], } @@ -637,6 +674,7 @@ def prune( samples=None, channels=None, measurements=None, + mode="logical_or", ): """ Return a new, pruned workspace specification. This will not modify the original workspace. @@ -649,6 +687,7 @@ def prune( samples: A :obj:`list` of samples to prune. channels: A :obj:`list` of channels to prune. measurements: A :obj:`list` of measurements to prune. + mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`. Returns: ~pyhf.workspace.Workspace: A new workspace object with the specified components removed @@ -657,6 +696,12 @@ def prune( ~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace. """ + + if mode not in ["logical_and", "logical_or"]: + raise ValueError( + "Pruning mode must be either `logical_and` or `logical_or`." + ) + # avoid mutable defaults modifiers = [] if modifiers is None else modifiers modifier_types = [] if modifier_types is None else modifier_types @@ -664,12 +709,32 @@ def prune( channels = [] if channels is None else channels measurements = [] if measurements is None else measurements - return self._prune_and_rename( - prune_modifiers=modifiers, - prune_modifier_types=modifier_types, - prune_samples=samples, - prune_channels=channels, - prune_measurements=measurements, + if mode == "logical_and": + if samples != [] and measurements != []: + raise ValueError( + "Pruning of measurements and samples cannot be run with mode `logical_and`." + ) + if channels != [] and measurements != []: + raise ValueError( + "Pruning of measurements and channels cannot be run with mode `logical_and`." + ) + if modifier_types != [] and measurements != []: + raise ValueError( + "Pruning of measurements and modifier_types cannot be run with mode `logical_and`." + ) + return self._prune_and_rename( + prune_modifiers=modifiers, + prune_modifier_types=modifier_types, + prune_samples=samples, + prune_channels=channels, + prune_measurements=measurements, + ) + return ( + self._prune_and_rename(prune_modifiers=modifiers) + ._prune_and_rename(prune_modifier_types=modifier_types) + ._prune_and_rename(prune_samples=samples) + ._prune_and_rename(prune_channels=channels) + ._prune_and_rename(prune_measurements=measurements) ) def rename(self, modifiers=None, samples=None, channels=None, measurements=None):