feat(bio): add --skip-vss to skip over second stage
danellecline committed Oct 22, 2024
1 parent 580a67f commit 809188f
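The new --skip-vss flag is threaded through as a skip_vss keyword on process_videos and run_inference. A minimal sketch of calling the updated process_videos directly, assuming the module is importable as aipipeline.projects.bio.run_strided_inference; the video directory, stride, and class name below are placeholder values, and the endpoint URL is the CLI default:

    from pathlib import Path

    from aipipeline.projects.bio.run_strided_inference import process_videos

    # Placeholder inputs; a real run goes through parse_args() and the CLI,
    # which now exposes the same behavior via --skip-vss.
    video_files = sorted(Path("/data/videos").glob("*.mp4"))
    process_videos(
        video_files,
        stride=2,                                      # sample every 2nd frame (illustrative)
        endpoint_url="http://localhost:8000/predict",  # CLI default from parse_args()
        class_name="Sebastes",                         # example class of interest
        version_id=0,
        skip_vss=True,                                 # bypass the second-stage VSS check
    )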
Showing 1 changed file with 38 additions and 30 deletions.
aipipeline/projects/bio/run_strided_inference.py (38 additions, 30 deletions)
@@ -163,6 +163,7 @@ def run_inference(
     endpoint_url: str,
     class_name: str,
     version_id: int = 0,
+    skip_vss: bool = False,
 ):
     """
     Run inference on a video file and queue the localizations in REDIS
@@ -248,36 +249,40 @@ def seconds_to_timestamp(seconds):
             for loc in data:
                 if loc["class_name"] == class_name:
                     # For low confidence detections, run through the vss model
+                    # Crop the image to the bounding box
+                    crop_path = output_path / f"{video_path.stem}_{index}_crop.jpg"
+                    data = {
+                        "image_path": output_frame.as_posix(),
+                        "crop_path": crop_path.as_posix(),
+                        "image_width": frame_width,
+                        "image_height": frame_height,
+                        "x": loc["x"] / frame_width,
+                        "y": loc["y"] / frame_height,
+                        "xx": (loc["x"] + loc["width"]) / frame_width,
+                        "xy": (loc["y"] + loc["height"]) / frame_height,
+                    }
+                    s = pd.Series(data)
+                    crop_square_image(s, 224)
                     if loc["confidence"] < 0.9:
-                        logger.info(
-                            f"{video_path.name}: running VSS model on low confidence {class_name} detection {loc['confidence']}")
-                        # Crop the image to the bounding box
-                        crop_path = output_path / f"{video_path.stem}_{index}_crop.jpg"
-                        data = {
-                            "image_path": output_frame.as_posix(),
-                            "crop_path": crop_path.as_posix(),
-                            "image_width": frame_width,
-                            "image_height": frame_height,
-                            "x": loc["x"]/frame_width,
-                            "y": loc["y"]/frame_height,
-                            "xx": (loc["x"] + loc["width"]) / frame_width,
-                            "xy": (loc["y"] + loc["height"]) / frame_height,
-                        }
-                        s = pd.Series(data)
-                        crop_square_image(s, 224)
-                        images = [read_image(crop_path.as_posix())]
-                        file_paths, best_predictions, best_scores = run_vss(images, config_dict, top_k=3)
-                        crop_path.unlink()
-                        if len(best_predictions) == 0:
-                            logger.info(f"{video_path.name}: no predictions from VSS model. Skipping this detection.")
-                            continue
-                        if best_predictions[0] != class_name:
-                            logger.info(
-                                f"{video_path.name}: VSS model prediction {best_predictions[0]} does not match {class_name}. Skipping this detection.")
-                            continue
-                        logger.info(f"===>{video_path.name}: VSS model prediction {best_predictions[0]} matches {class_name}<====")
+                        if not skip_vss:
+                            logger.info(
+                                f"{video_path.name}: running VSS model on low confidence {class_name} detection {loc['confidence']}")
+
+                            images = [read_image(crop_path.as_posix())]
+                            file_paths, best_predictions, best_scores = run_vss(images, config_dict, top_k=3)
+                            crop_path.unlink()
+                            if len(best_predictions) == 0:
+                                logger.info(f"{video_path.name}: no predictions from VSS model. Skipping this detection.")
+                                continue
+                            if best_predictions[0] != class_name:
+                                logger.info(
+                                    f"{video_path.name}: VSS model prediction {best_predictions[0]} does not match {class_name}. Skipping this detection.")
+                                continue
+                            logger.info(f"===>{video_path.name}: VSS model prediction {best_predictions[0]} matches {class_name}<====")
+                        else:
+                            logger.info(f"{video_path.name}: {class_name} detection {loc['confidence']}")
                     else:
-                        logger.info(f"====>{video_path.name}: high confidence {class_name} detection {loc['confidence']}<====")
+                        logger.info(f"====>{video_path.name}: {class_name} detection {loc['confidence']}<====")
 
                     if not queued_video:
                         queued_video = True
@@ -355,12 +360,12 @@ def seconds_to_timestamp(seconds):
         jpg_file.unlink()
 
 
-def process_videos(video_files, stride, endpoint_url, config_dict, class_name, version_id):
+def process_videos(video_files, stride, endpoint_url, class_name, version_id, skip_vss=False):
     num_cpus = multiprocessing.cpu_count()
     pool = multiprocessing.Pool(processes=num_cpus)
     pool.starmap(
         run_inference,
-        [(v, stride, endpoint_url, class_name, version_id) for v in video_files],
+        [(v, stride, endpoint_url, class_name, version_id, skip_vss) for v in video_files],
     )
     pool.close()
     pool.join()
@@ -399,6 +404,7 @@ def parse_args():
         default="http://localhost:8000/predict",
         type=str,
     )
+    parser.add_argument("--skip-vss", help="Skip running VSS model on low confidence detections.", action="store_true")
     parser.add_argument("--flush", help="Flush the REDIS database.", action="store_true")
     return parser.parse_args()

@@ -469,6 +475,7 @@ def parse_args():
             args.endpoint_url,
             args.class_name,
             version_id,
+            skip_vss=args.skip_vss,
         )
     elif video_path.is_dir():
         # Fanout to number of CPUs
@@ -478,9 +485,9 @@ def parse_args():
             video_files,
             args.stride,
             args.endpoint_url,
-            config_dict,
             args.class_name,
             version_id,
+            skip_vss=args.skip_vss,
         )
     else:
         logger.error(f"Invalid video path: {video_path}")
@@ -501,6 +508,7 @@ def parse_args():
             args.endpoint_url,
             args.class_name,
             version_id,
+            skip_vss=args.skip_vss,
         )
 
     logger.info("Finished processing videos")
