From cfcf4b0c7e66cf4379c8689955537643d9392483 Mon Sep 17 00:00:00 2001 From: David Sapiro <115489098+Encord-davids@users.noreply.github.com> Date: Mon, 23 Jan 2023 13:46:14 +0000 Subject: [PATCH] fix(embeddings): skip broken files (#111) --- src/encord_active/lib/embeddings/cnn.py | 39 +++++++++++++++--------- src/encord_active/lib/project/project.py | 7 +++-- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/encord_active/lib/embeddings/cnn.py b/src/encord_active/lib/embeddings/cnn.py index 1225d67bd..09a7fd580 100644 --- a/src/encord_active/lib/embeddings/cnn.py +++ b/src/encord_active/lib/embeddings/cnn.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import List, Optional, Tuple +import numpy as np import torch import torchvision.transforms as torch_transforms from encord.objects.common import PropertyType @@ -94,13 +95,10 @@ def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]: collections: List[LabelEmbedding] = [] for data_unit, img_pth in iterator.iterate(desc="Embedding image data."): - if img_pth is None: - continue + embedding = get_embdding_for_image(feature_extractor, transforms, img_pth) - image = image_path_to_tensor(img_pth) - transformed_image = transforms(image).unsqueeze(0) - embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"] - embedding = torch.flatten(embedding).cpu().detach().numpy() + if embedding is None: + continue entry = LabelEmbedding( url=data_unit["data_link"], @@ -192,9 +190,6 @@ def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbe collections = [] for data_unit, img_pth in iterator.iterate(desc="Embedding classification data."): - if not img_pth: - continue - matching_image_collections = [ collection for collection in image_collections @@ -203,13 +198,13 @@ def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbe and collection["frame"] == iterator.frame ] - if len(image_collections): - embedding = matching_image_collections[0]["embedding"] + if not len(image_collections): + embedding = get_embdding_for_image(feature_extractor, transforms, img_pth) else: - image = image_path_to_tensor(img_pth) - transformed_image = transforms(image).unsqueeze(0) - embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"] - embedding = torch.flatten(embedding).cpu().detach().numpy() # type: ignore + embedding = matching_image_collections[0]["embedding"] + + if embedding is None: + continue classification_answers = iterator.label_rows[iterator.label_hash]["classification_answers"] for classification in data_unit["labels"].get("classifications", []): @@ -304,3 +299,17 @@ def generate_cnn_embeddings(iterator: Iterator, embedding_type: EmbeddingType, t logger.info("Done!") return cnn_embeddings + + +def get_embdding_for_image(feature_extractor, transforms, img_pth: Optional[Path] = None) -> Optional[np.ndarray]: + if img_pth is None: + return None + + try: + image = image_path_to_tensor(img_pth) + transformed_image = transforms(image).unsqueeze(0) + embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"] + return torch.flatten(embedding).cpu().detach().numpy() + except: + logger.error(f"Falied generating embedding for file: {img_pth}") + return None diff --git a/src/encord_active/lib/project/project.py b/src/encord_active/lib/project/project.py index 21b524100..c7a8201f6 100644 --- a/src/encord_active/lib/project/project.py +++ b/src/encord_active/lib/project/project.py @@ -217,8 +217,11 @@ def download_images_from_data_unit(lr, project_file_structure: ProjectFileStruct for du in data_units: suffix = f".{du['data_type'].split('/')[1]}" out_pth = (lr_structure.images_dir / du["data_hash"]).with_suffix(suffix) - out_pth = download_file(du["data_link"], out_pth) - frame_pths.append(out_pth) + try: + out_pth = download_file(du["data_link"], out_pth) + frame_pths.append(out_pth) + except: + logger.warning(f"Could not download data unit `{du['data_hash']}`, skipping...") if lr.data_type == "video": video_path = frame_pths[0]