Commit ccadf54

Add jina clip evaluator

1 parent 20135bd commit ccadf54

File tree: 2 files changed, +175 −53 lines

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/base_custom_evaluator.py

Lines changed: 2 additions & 1 deletion

@@ -256,7 +256,8 @@ def reset(self):
 
     def release(self):
        self._release_model()
-        self.launcher.release()
+        if self.launcher:
+            self.launcher.release()
 
     def _release_model(self):
        if self.model:
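The guard matters because the new Jina evaluator added below can be constructed without an OpenVINO launcher: its pytorch path builds the pipeline with launcher=None, so release() would previously fail. A minimal sketch of the failure mode (the config object here is hypothetical):

    # Sketch only, not part of the commit: with a pytorch launcher config,
    # from_configs() leaves launcher as None.
    evaluator = OpenVinoJinaClipEvaluator.from_configs(config)
    evaluator.release()  # before this fix: AttributeError on None.release()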

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/openvino_clip_evaluator.py

Lines changed: 173 additions & 52 deletions

@@ -1,5 +1,5 @@
 """
-Copyright (c) 2024 Intel Corporation
+Copyright (c) 2024-2025 Intel Corporation
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
 """
 import os
 import numpy as np
+from PIL import Image
+from scipy.special import softmax
 
 from .base_custom_evaluator import BaseCustomEvaluator
 from .base_models import BaseCascadeModel
@@ -30,9 +32,19 @@
 
 try:
     import open_clip
-except ImportError as error:
-    open_clip = UnsupportedPackage('open_clip', error.msg)
+except ImportError as clip_error:
+    open_clip = UnsupportedPackage('open_clip', clip_error.msg)
+
+try:
+    from transformers import AutoModel, AutoTokenizer
+except ImportError as transformers_error:
+    AutoModel = UnsupportedPackage('AutoModel', transformers_error.msg)
+    AutoTokenizer = UnsupportedPackage('AutoTokenizer', transformers_error.msg)
 
+try:
+    import torch
+except ImportError as torch_error:
+    torch = UnsupportedPackage("torch", torch_error.msg)
 
 class OpenVinoClipEvaluator(BaseCustomEvaluator):
     def __init__(self, dataset_config, launcher, model, orig_config):
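The new imports follow the same deferred-failure pattern already used for open_clip: a missing optional package is replaced by an UnsupportedPackage stub, and the error surfaces only when the Jina pipeline actually needs it. A sketch of how the stub is consumed later in this file:

    # Pattern (as used in OpenVinoJinaClipModel.create_pipeline below):
    if isinstance(torch, UnsupportedPackage):
        torch.raise_error("OpenVinoJinaClipModel")  # re-raises the original ImportError message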
@@ -42,8 +54,7 @@ def __init__(self, dataset_config, launcher, model, orig_config):
     @classmethod
     def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
         dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
-
-        model = OpenVinoClipModel(
+        model = OpenVinoClipVitModel(
             config.get('network_info', {}), launcher, config.get('_models', []),
             config.get('_model_is_blob'),
             delayed_model_loading, config
@@ -69,43 +80,52 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
             self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)
 
 
-class OpenVinoClipModel(BaseCascadeModel):
+class OpenVinoJinaClipEvaluator(OpenVinoClipEvaluator):
+    @classmethod
+    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
+        if config['launchers'][0]['framework'] == 'pytorch':
+            dataset_config, launcher = config["datasets"], None
+            delayed_model_loading = False
+        else:
+            dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
+
+        model = OpenVinoJinaClipModel(
+            config.get('network_info', {}), launcher, config.get('_models', []),
+            config.get('_model_is_blob'),
+            delayed_model_loading, config
+        )
+        return cls(dataset_config, launcher, model, orig_config)
+
+
+class BaseOpenVinoClipModel(BaseCascadeModel):
     def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False, config=None):
         super().__init__(network_info, launcher, delayed_model_loading)
         self.network_info = network_info
         self.launcher = launcher
         self.config = config or {}
-        parts = ['text_encoder', 'image_encoder']
-        network_info = self.fill_part_with_model(network_info, parts, models_args, False, delayed_model_loading)
-        if not contains_all(network_info, parts) and not delayed_model_loading:
-            raise ConfigError('configuration for text_encoder/image_encoder does not exist')
+        self.templates_file = None
+        self.parameters_file = None
+        self.templates = ["a photo of a {classname}"]
+        self.parts = network_info.keys()
+        if launcher:
+            network_info = self.fill_part_with_model(network_info, self.parts,
+                                                     models_args, False, delayed_model_loading)
+            if not contains_all(network_info, self.parts) and not delayed_model_loading:
+                raise ConfigError('configuration for text_encoder/image_encoder does not exist')
        if not delayed_model_loading:
            self.create_pipeline(launcher, network_info)
 
     def create_pipeline(self, launcher, network_info):
-        orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
-        self.load_models(network_info, launcher, True)
-
-        self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
-        self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
-
-        unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
-        if unet_shapes[0][0].is_dynamic:
-            self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
-        else:
-            self.templates_file = None
+        raise NotImplementedError("Subclasses should implement this method")
 
-        self.classnames_file = self.config.get("classnames", "classnames.json")
-        self.parameters_file = self.config.get("pretrained_model_params", None)
-        self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+    def get_logits(self, image_features, zeroshot_weights):
+        raise NotImplementedError("Subclasses should implement this method")
 
     def predict(self, identifiers, input_data, zeroshot_weights):
         preds = []
         for idx, image_data in zip(identifiers, input_data):
-            image = np.expand_dims(image_data, axis=0)
-            image_features = self.encode_image(image)
-            image_features = self.normalize(image_features, axis=-1)
-            logits = 100. * image_features @ zeroshot_weights
+            image_features = self.encode_image(image_data)
+            logits = self.get_logits(image_features, zeroshot_weights)
             preds.append(ClassificationPrediction(idx, np.squeeze(logits, axis=0)))
         return None, preds
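predict() is now model-agnostic: the model-specific work moved into two hooks, encode_image() and get_logits(), which OpenVinoClipVitModel and OpenVinoJinaClipModel override further down. A toy subclass (synthetic sketch, not part of the commit) showing the contract the base class expects:

    import numpy as np

    class ToyClipModel(BaseOpenVinoClipModel):
        def encode_image(self, image_data):
            # one image in -> a (1, dim) embedding out
            return image_data.reshape(1, -1).astype(np.float32)

        def get_logits(self, image_features, zeroshot_weights):
            # score the embedding against the stacked per-class weights
            return 100. * image_features @ zeroshot_weights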

@@ -116,23 +136,11 @@ def get_network(self):
             model_list.append({"name": model_part_name, "model": model})
         return model_list
 
-    def encode_image(self, image):
-        features = self.image_encoder(image)
-        return features[self.image_encoder.output()]
+    def encode_image(self, image_data):
+        raise NotImplementedError("Subclasses should implement this method")
 
     def encode_text(self, texts, params):
-        text = self.tokenizer(texts).to('cpu')
-        indices = text.detach().cpu().numpy()
-
-        x = params['token_embedding'][indices]
-        x = x + params['positional_embedding']
-        x = x.transpose(1, 0, 2)
-        x = self.text_encoder((x, params['attn_mask']))
-        x = x[self.text_encoder.output()]
-        x = x.transpose(1, 0, 2)
-        x = self.layer_norm(x, params['gamma'], params['beta'])
-        x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
-        return x
+        raise NotImplementedError("Subclasses should implement this method")
 
     @staticmethod
     def get_pretrained_model_params(path):
@@ -147,14 +155,20 @@ def get_pretrained_model_params(path):
         params['beta'] = open_clip_params['beta']
         return params
 
+    def get_class_embeddings(self, texts, params):
+        raise NotImplementedError("Subclasses should implement this method")
+
     def zero_shot_classifier(self, data_source):
         classnames = read_json(os.path.join(data_source, self.classnames_file))
         if self.templates_file:
             templates = read_json(os.path.join(data_source, self.templates_file))
         else:
-            templates = ["a photo of a {c}"]
+            templates = self.templates
+
+        params = None
+        if self.parameters_file:
+            params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
 
-        params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
         print_info('Encoding zeroshot weights for {} imagenet classes'.format(len(classnames)))
 
         zeroshot_weights = []
@@ -163,12 +177,9 @@ def zero_shot_classifier(self, data_source):
         iterator = tqdm(classnames, mininterval=2)
 
         for classname in iterator:
-            texts = [template.format(c=classname) for template in templates]
-            class_embeddings = self.encode_text(texts, params)
-            class_embedding = self.normalize(class_embeddings, axis=-1)
-            class_embedding = np.mean(class_embedding, axis=0)
-            class_embedding /= np.linalg.norm(class_embedding, ord=2)
-            zeroshot_weights.append(class_embedding)
+            texts = [template.format(classname=classname) for template in templates]
+            class_embeddings = self.get_class_embeddings(texts, params)
+            zeroshot_weights.append(class_embeddings)
         return np.stack(zeroshot_weights, axis=1)
 
     def load_models(self, network_info, launcher, log=False):
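Note the placeholder rename in the templates: format(c=classname) becomes format(classname=classname), matching the new default template "a photo of a {classname}". For example:

    templates = ["a photo of a {classname}"]
    texts = [t.format(classname="goldfish") for t in templates]
    # -> ["a photo of a goldfish"]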
@@ -192,7 +203,7 @@ def load_model(self, network_list, launcher):
         setattr(self, "{}_model".format(network_list["name"]), network)
 
     def print_input_output_info(self):
-        model_parts = ("text_encoder", "image_encoder")
+        model_parts = self.parts
         for part in model_parts:
             part_model_id = "{}_model".format(part)
             model = getattr(self, part_model_id, None)
@@ -218,3 +229,113 @@ def normalize(input_array, p=2, axis=-1, epsilon=1e-12):
         norm = np.maximum(norm, epsilon)
         normalized = input_array / norm
         return normalized
+
+
+class OpenVinoClipVitModel(BaseOpenVinoClipModel):
+    def create_pipeline(self, launcher, network_info):
+        orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
+        self.load_models(network_info, launcher, True)
+        self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
+        self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
+        unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
+        if unet_shapes[0][0].is_dynamic:
+            self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
+
+        self.classnames_file = self.config.get("classnames", "classnames.json")
+        self.parameters_file = self.config.get("pretrained_model_params", None)
+        self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+
+    def get_logits(self, image_features, zeroshot_weights):
+        image_features = self.normalize(image_features, axis=-1)
+        logits = 100. * image_features @ zeroshot_weights
+        return logits
+
+    def encode_image(self, image_data):
+        image = np.expand_dims(image_data, axis=0)
+        features = self.image_encoder(image)
+        return features[self.image_encoder.output()]
+
+    def encode_text(self, texts, params):
+        text = self.tokenizer(texts).to('cpu')
+        indices = text.detach().cpu().numpy()
+
+        x = params['token_embedding'][indices]
+        x = x + params['positional_embedding']
+        x = x.transpose(1, 0, 2)
+        x = self.text_encoder((x, params['attn_mask']))
+        x = x[self.text_encoder.output()]
+        x = x.transpose(1, 0, 2)
+        x = self.layer_norm(x, params['gamma'], params['beta'])
+        x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
+        return x
+
+    def get_class_embeddings(self, texts, params):
+        class_embeddings = self.encode_text(texts, params)
+        class_embedding = self.normalize(class_embeddings, axis=-1)
+        class_embedding = np.mean(class_embedding, axis=0)
+        class_embedding /= np.linalg.norm(class_embedding, ord=2)
+        return class_embedding
+
+
+class OpenVinoJinaClipModel(BaseOpenVinoClipModel):
+    def create_pipeline(self, launcher, network_info):
+        if isinstance(AutoTokenizer, UnsupportedPackage):
+            AutoTokenizer.raise_error(self.__class__.__name__)
+        if isinstance(AutoModel, UnsupportedPackage):
+            AutoModel.raise_error(self.__class__.__name__)
+        if isinstance(torch, UnsupportedPackage):
+            torch.raise_error(self.__class__.__name__)
+
+        orig_model_name = self.config.get("orig_model_name", "jinaai/jina-clip-v1")
+
+        model = AutoModel.from_pretrained(orig_model_name, trust_remote_code=True)
+        if launcher:
+            self.load_models(network_info, launcher, True)
+            self.text_encoder = launcher.ie_core.compile_model(self.text_model, launcher.device)
+            self.vision_encoder = launcher.ie_core.compile_model(self.vision_model, launcher.device)
+        else:
+            self.text_encoder = model.text_model
+            self.vision_encoder = model.vision_model
+
+        self.templates = ["{classname}"]
+        self.classnames_file = self.config.get("classnames", "classnames.json")
+        self.tokenizer = AutoTokenizer.from_pretrained(orig_model_name, trust_remote_code=True)
+        self.processor = model.get_preprocess()
+
+    def encode_image(self, image_data):
+        image = Image.fromarray(image_data)
+        vision_input = self.processor(images=[image], return_tensors="pt")
+        image_embeddings = self.vision_encoder(vision_input["pixel_values"])
+
+        if isinstance(image_embeddings, torch.Tensor):
+            image_embeddings = image_embeddings.detach().numpy()
+        else:
+            image_embeddings = image_embeddings[0]
+
+        return image_embeddings
+
+    def encode_text(self, text_input):
+        text_embeddings = self.text_encoder(text_input["input_ids"])
+
+        if isinstance(text_embeddings, torch.Tensor):
+            text_embeddings = text_embeddings.detach().numpy()
+        else:
+            text_embeddings = text_embeddings[0]
+        return text_embeddings
+
+    def get_logits(self, image_features, zeroshot_weights):
+        text_embeddings = np.squeeze(zeroshot_weights)
+        similarity = []
+        for emb1 in image_features:
+            temp_similarity = []
+            for emb2 in text_embeddings:
+                temp_similarity.append(emb1 @ emb2)
+            similarity.append(temp_similarity)
+
+        logits = 100. * softmax(similarity)
+        return logits
+
+    def get_class_embeddings(self, texts, params):
+        text_input = self.tokenizer(texts, return_tensors="pt", padding="max_length",
+                                    max_length=512, truncation=True).to("cpu")
+        return self.encode_text(text_input)
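For the Jina path, zero_shot_classifier() stacks one (1, dim) embedding per class along axis=1, and get_logits() scores an image embedding against every class. A shape check with synthetic data (no model involved), which also shows the nested loop is equivalent to a single matrix product:

    import numpy as np
    from scipy.special import softmax

    num_classes, dim = 3, 8
    image_features = np.random.rand(1, dim)               # encode_image() output
    class_embs = [np.random.rand(1, dim) for _ in range(num_classes)]
    zeroshot_weights = np.stack(class_embs, axis=1)       # (1, num_classes, dim)

    text_embeddings = np.squeeze(zeroshot_weights)        # (num_classes, dim)
    similarity = image_features @ text_embeddings.T       # same values as the loop above
    logits = 100. * softmax(similarity)                   # (1, num_classes)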
