Skip to content

Commit c9711b8

Browse files
committed
Add test for DETR BetterTransformer
Signed-off-by: Issam Arabi <issam@cs.toronto.edu>
1 parent 937bd99 commit c9711b8

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/bettertransformer/test_vision.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
2727
r"""
2828
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
2929
"""
30-
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
30+
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "detr", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
3131

3232
def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
3333
if model_type == "vilt":
@@ -56,6 +56,14 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc
5656

5757
if model_type == "blip-2":
5858
inputs["decoder_input_ids"] = inputs["input_ids"]
59+
60+
elif model_type == "detr":
61+
# Assuming detr just needs an image
62+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
63+
image = Image.open(requests.get(url, stream=True).raw)
64+
65+
feature_extractor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-DetrModel")
66+
inputs = feature_extractor(images=image, return_tensors="pt")
5967

6068
else:
6169
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

0 commit comments

Comments
 (0)