From be29708d34d30202a0bd31ae2d7f73fa1e9781bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= Date: Wed, 28 Feb 2024 16:22:12 +0100 Subject: [PATCH 1/3] fix: add current device to the OpenCLIP model --- clip_eval/models/CLIP_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py index 6d14ea8..221e1a0 100644 --- a/clip_eval/models/CLIP_model.py +++ b/clip_eval/models/CLIP_model.py @@ -155,7 +155,7 @@ def _setup(self, **kwargs) -> None: model, _, preprocess = open_clip.create_model_and_transforms( model_name=self.model_name, pretrained=self.pretrained, **kwargs ) - self.model = model + self.model = model.to(self.device) self.processor = preprocess def build_embedding(self, dataloader: DataLoader): From 7c361069a7635f4af4f62e3c9408ee6ff4d0c0be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= Date: Wed, 28 Feb 2024 16:23:46 +0100 Subject: [PATCH 2/3] misc: update class naming to camel case --- clip_eval/models/CLIP_model.py | 8 ++++++-- clip_eval/models/provider.py | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py index 221e1a0..3bd687f 100644 --- a/clip_eval/models/CLIP_model.py +++ b/clip_eval/models/CLIP_model.py @@ -67,7 +67,7 @@ def _check_device(device: str): raise ValueError(f"Unavailable device: {device}") -class closed_CLIPModel(CLIPModel): +class ClosedCLIPModel(CLIPModel): def __init__(self, title: str, title_in_source: str, device: str | None = None) -> None: super().__init__(title, title_in_source, device) @@ -115,7 +115,11 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class return image_embeddings, labels -class open_CLIPModel(CLIPModel): +class sds_A: + pass + + +class OpenCLIPModel(CLIPModel): def __init__( self, title: str, diff --git a/clip_eval/models/provider.py b/clip_eval/models/provider.py index 33d1055..b8ed117 100644 --- a/clip_eval/models/provider.py +++ b/clip_eval/models/provider.py @@ -1,4 +1,4 @@ -from .CLIP_model import CLIPModel, SiglipModel, closed_CLIPModel, open_CLIPModel +from .CLIP_model import CLIPModel, ClosedCLIPModel, OpenCLIPModel, SiglipModel class ModelProvider: @@ -19,24 +19,24 @@ def list_model_names(self) -> list[str]: model_provider = ModelProvider() -model_provider.register_model("clip", closed_CLIPModel, title_in_source="openai/clip-vit-large-patch14-336") -model_provider.register_model("plip", closed_CLIPModel, title_in_source="vinid/plip") +model_provider.register_model("clip", ClosedCLIPModel, title_in_source="openai/clip-vit-large-patch14-336") +model_provider.register_model("plip", ClosedCLIPModel, title_in_source="vinid/plip") model_provider.register_model( "pubmed", - closed_CLIPModel, + ClosedCLIPModel, title_in_source="flaviagiammarino/pubmed-clip-vit-base-patch32", ) -model_provider.register_model("bioclip", closed_CLIPModel, title_in_source="imageomics/bioclip") +model_provider.register_model("bioclip", ClosedCLIPModel, title_in_source="imageomics/bioclip") model_provider.register_model( "tinyclip", - closed_CLIPModel, + ClosedCLIPModel, title_in_source="wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M", ) -model_provider.register_model("fashion", closed_CLIPModel, title_in_source="patrickjohncyh/fashion-clip") -model_provider.register_model("rscid", closed_CLIPModel, title_in_source="flax-community/clip-rsicd") -model_provider.register_model("street", closed_CLIPModel, title_in_source="geolocal/StreetCLIP") -model_provider.register_model("apple", open_CLIPModel, title_in_source="apple/DFN5B-CLIP-ViT-H-14") -model_provider.register_model("eva-clip", open_CLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448") -model_provider.register_model("vit-b-32-laion2b", open_CLIPModel, model_name="ViT-B-32", pretrained="laion2b_e16") +model_provider.register_model("fashion", ClosedCLIPModel, title_in_source="patrickjohncyh/fashion-clip") +model_provider.register_model("rscid", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd") +model_provider.register_model("street", ClosedCLIPModel, title_in_source="geolocal/StreetCLIP") +model_provider.register_model("apple", OpenCLIPModel, title_in_source="apple/DFN5B-CLIP-ViT-H-14") +model_provider.register_model("eva-clip", OpenCLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448") +model_provider.register_model("vit-b-32-laion2b", OpenCLIPModel, model_name="ViT-B-32", pretrained="laion2b_e16") model_provider.register_model("siglip_small", SiglipModel, title_in_source="google/siglip-base-patch16-224") model_provider.register_model("siglip_large", SiglipModel, title_in_source="google/siglip-large-patch16-256") From 942289c92f7da0fa9de878faf8f058c95b40ef8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= Date: Wed, 28 Feb 2024 17:08:12 +0100 Subject: [PATCH 3/3] fix: remove test code --- clip_eval/models/CLIP_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py index 3bd687f..2214009 100644 --- a/clip_eval/models/CLIP_model.py +++ b/clip_eval/models/CLIP_model.py @@ -115,10 +115,6 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class return image_embeddings, labels -class sds_A: - pass - - class OpenCLIPModel(CLIPModel): def __init__( self,