Skip to content

Commit

Permalink
fix(embeddings): skip broken files (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
Encord-davids authored Jan 23, 2023
1 parent 854d113 commit cfcf4b0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
39 changes: 24 additions & 15 deletions src/encord_active/lib/embeddings/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as torch_transforms
from encord.objects.common import PropertyType
Expand Down Expand Up @@ -94,13 +95,10 @@ def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]:

collections: List[LabelEmbedding] = []
for data_unit, img_pth in iterator.iterate(desc="Embedding image data."):
if img_pth is None:
continue
embedding = get_embdding_for_image(feature_extractor, transforms, img_pth)

image = image_path_to_tensor(img_pth)
transformed_image = transforms(image).unsqueeze(0)
embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"]
embedding = torch.flatten(embedding).cpu().detach().numpy()
if embedding is None:
continue

entry = LabelEmbedding(
url=data_unit["data_link"],
Expand Down Expand Up @@ -192,9 +190,6 @@ def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbe

collections = []
for data_unit, img_pth in iterator.iterate(desc="Embedding classification data."):
if not img_pth:
continue

matching_image_collections = [
collection
for collection in image_collections
Expand All @@ -203,13 +198,13 @@ def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbe
and collection["frame"] == iterator.frame
]

if len(image_collections):
embedding = matching_image_collections[0]["embedding"]
if not len(image_collections):
embedding = get_embdding_for_image(feature_extractor, transforms, img_pth)
else:
image = image_path_to_tensor(img_pth)
transformed_image = transforms(image).unsqueeze(0)
embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"]
embedding = torch.flatten(embedding).cpu().detach().numpy() # type: ignore
embedding = matching_image_collections[0]["embedding"]

if embedding is None:
continue

classification_answers = iterator.label_rows[iterator.label_hash]["classification_answers"]
for classification in data_unit["labels"].get("classifications", []):
Expand Down Expand Up @@ -304,3 +299,17 @@ def generate_cnn_embeddings(iterator: Iterator, embedding_type: EmbeddingType, t
logger.info("Done!")

return cnn_embeddings


def get_embdding_for_image(feature_extractor, transforms, img_pth: Optional[Path] = None) -> Optional[np.ndarray]:
if img_pth is None:
return None

try:
image = image_path_to_tensor(img_pth)
transformed_image = transforms(image).unsqueeze(0)
embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"]
return torch.flatten(embedding).cpu().detach().numpy()
except:
logger.error(f"Falied generating embedding for file: {img_pth}")
return None
7 changes: 5 additions & 2 deletions src/encord_active/lib/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,11 @@ def download_images_from_data_unit(lr, project_file_structure: ProjectFileStruct
for du in data_units:
suffix = f".{du['data_type'].split('/')[1]}"
out_pth = (lr_structure.images_dir / du["data_hash"]).with_suffix(suffix)
out_pth = download_file(du["data_link"], out_pth)
frame_pths.append(out_pth)
try:
out_pth = download_file(du["data_link"], out_pth)
frame_pths.append(out_pth)
except:
logger.warning(f"Could not download data unit <blue>`{du['data_hash']}`</blue>, skipping...")

if lr.data_type == "video":
video_path = frame_pths[0]
Expand Down

0 comments on commit cfcf4b0

Please sign in to comment.