perf: load all detections if less than 100 total examples
danellecline committed Oct 23, 2024
1 parent 25bd19f commit 7ec6053
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions aipipeline/prediction/vss_init_pipeline.py
@@ -58,24 +58,30 @@ def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
         all_exemplars = list(Path(save_dir).rglob("*_exemplars.csv"))
         exemplar_file = sorted(all_exemplars, key=os.path.getmtime, reverse=True)[0] if all_exemplars else None
 
+        # Grab the most recent detections file
+        all_detections = list(Path(save_dir).rglob("*cluster_detections.csv"))
+        detection_file = sorted(all_detections, key=os.path.getmtime, reverse=True)[0] if all_detections else None
+
         if exemplar_file is None:
             logger.info(f"No exemplar file found for {label}")
             exemplar_count = 0
         else:
             with open(exemplar_file, "r") as f:
                 exemplar_count = len(f.readlines())
+        if detection_file is None:
+            logger.info(f"No detection file found for {label}")
+            detection_count = 0
+        else:
+            with open(detection_file, "r") as f:
+                detection_count = len(f.readlines())
 
-        if exemplar_count < 10 or exemplar_file is None:
-            all_detections = list(Path(save_dir).rglob("*_detections.csv"))
-            exemplar_file = sorted(all_detections, key=os.path.getmtime, reverse=True)[0] if all_detections else None
-
-            if exemplar_file is None:
-                logger.info(f"No detections file found for {label}")
-                return f"No exemplar or detections file found for {label}"
+        if exemplar_count == 0 or detection_count == 0:
+            logger.info(f"No exemplars or detections found for {label}")
+            continue
 
-            with open(exemplar_file, "r") as f:
-                exemplar_count = len(f.readlines())
-            logger.info(f"To few exemplars, using detections file {exemplar_file} instead")
+        if exemplar_count < 10 or detection_count < 100:
+            logger.info(f"Too few exemplars, using detections file {detection_file} instead")
+            exemplar_file = detection_file
 
         logger.info(f"Loading {exemplar_count} exemplars for {label} as {label} from {exemplar_file}")
         args = [
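
The net effect of this hunk: the fallback to detections now keys off both counts, fewer than 10 exemplars or fewer than 100 detections in total, and in that case the detections file is loaded in place of the exemplars. A minimal standalone sketch of the new selection logic (count_lines and pick_load_file are illustrative helpers, not part of the repository):

    import os
    from pathlib import Path

    def count_lines(path):
        # Rows in a CSV; 0 when the file is missing.
        if path is None:
            return 0
        with open(path, "r") as f:
            return len(f.readlines())

    def pick_load_file(save_dir):
        # Most recent exemplar and detection CSVs, newest first by mtime,
        # mirroring the rglob/getmtime pattern in the hunk above.
        exemplars = sorted(Path(save_dir).rglob("*_exemplars.csv"), key=os.path.getmtime, reverse=True)
        detections = sorted(Path(save_dir).rglob("*cluster_detections.csv"), key=os.path.getmtime, reverse=True)
        exemplar_file = exemplars[0] if exemplars else None
        detection_file = detections[0] if detections else None
        exemplar_count = count_lines(exemplar_file)
        detection_count = count_lines(detection_file)
        if exemplar_count == 0 or detection_count == 0:
            return None  # nothing to load for this label
        if exemplar_count < 10 or detection_count < 100:
            return detection_file  # too few examples: load all detections instead
        return exemplar_file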
@@ -106,9 +112,9 @@ def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
                 bind_volumes=dict(config_dict["docker"]["bind_volumes"]),
             )
             if container:
-                logger.info(f"Loading cluster exemplars for {label}...")
+                logger.info(f"Loading cluster exemplars for {label} from {exemplar_file}...")
                 container.wait()
-                logger.info(f"Loaded cluster exemplars for {label}")
+                logger.info(f"Loaded cluster exemplars for {label} from {exemplar_file}")
                 num_loaded += 1
                 break
         else:
@@ -134,8 +140,8 @@ def run_pipeline(argv=None):
     parser.add_argument("--config", required=True, help=f"Config file path, e.g. {example_project}")
     parser.add_argument("--skip-clean", required=False, default=False, help="Skip cleaning of previously downloaded data")
     parser.add_argument("--skip-download", required=False, default=False, help="Skip downloading data")
-    parser.add_argument("--batch-size", required=False, default=3, help="Batch size")
-    parser.add_argument("--min-variance", required=False, default=3.0, help="Minimum variance for blurriness")
+    parser.add_argument("--batch-size", required=False, type=int, default=3, help="Batch size")
+    parser.add_argument("--min-variance", required=False, type=float, default=3.0, help="Minimum variance for blurriness")
     args, beam_args = parser.parse_known_args(argv)
 
     conf_files, config_dict = setup_config(args.config)
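
The argparse hunk adds explicit type= converters so --batch-size and --min-variance reach the pipeline as int and float; without a type, values supplied on the command line arrive as str. A quick sketch of the resulting behavior, with made-up argument values:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", required=False, type=int, default=3, help="Batch size")
    parser.add_argument("--min-variance", required=False, type=float, default=3.0, help="Minimum variance for blurriness")

    # parse_known_args returns the parsed namespace plus any unrecognized args
    # (here standing in for options passed through to Beam).
    args, beam_args = parser.parse_known_args(["--batch-size", "8", "--min-variance", "2.5", "--runner", "DirectRunner"])
    assert args.batch_size == 8 and isinstance(args.batch_size, int)
    assert args.min_variance == 2.5 and isinstance(args.min_variance, float)
    assert beam_args == ["--runner", "DirectRunner"]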