11"""
2- Copyright (c) 2024 Intel Corporation
2+ Copyright (c) 2024-2025 Intel Corporation
33
44Licensed under the Apache License, Version 2.0 (the "License");
55you may not use this file except in compliance with the License.
1515"""
1616import os
1717import numpy as np
18+ from PIL import Image
19+ from scipy .special import softmax
1820
1921from .base_custom_evaluator import BaseCustomEvaluator
2022from .base_models import BaseCascadeModel
@@ -30,9 +32,19 @@
 
 try:
     import open_clip
-except ImportError as error:
-    open_clip = UnsupportedPackage('open_clip', error.msg)
+except ImportError as clip_error:
+    open_clip = UnsupportedPackage('open_clip', clip_error.msg)
+
+try:
+    from transformers import AutoModel, AutoTokenizer
+except ImportError as transformers_error:
+    AutoModel = UnsupportedPackage('AutoModel', transformers_error.msg)
+    AutoTokenizer = UnsupportedPackage('AutoTokenizer', transformers_error.msg)
 
+try:
+    import torch
+except ImportError as torch_error:
+    torch = UnsupportedPackage("torch", torch_error.msg)
 
 class OpenVinoClipEvaluator(BaseCustomEvaluator):
     def __init__(self, dataset_config, launcher, model, orig_config):
@@ -42,8 +54,7 @@ def __init__(self, dataset_config, launcher, model, orig_config):
     @classmethod
     def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
-
-        model = OpenVinoClipModel(
+        model = OpenVinoClipVitModel(
             config.get('network_info', {}), launcher, config.get('_models', []),
             config.get('_model_is_blob'),
             delayed_model_loading, config
@@ -69,43 +80,52 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
             self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)
 
 
-class OpenVinoClipModel(BaseCascadeModel):
+class OpenVinoJinaClipEvaluator(OpenVinoClipEvaluator):
+    @classmethod
+    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
+        if config['launchers'][0]['framework'] == 'pytorch':
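+            # reference mode: no OpenVINO launcher, the original PyTorch model is run directly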
+            dataset_config, launcher = config["datasets"], None
+            delayed_model_loading = False
+        else:
+            dataset_config, launcher, _ = cls.get_dataset_and_launcher_info(config)
+
+        model = OpenVinoJinaClipModel(
+            config.get('network_info', {}), launcher, config.get('_models', []),
+            config.get('_model_is_blob'),
+            delayed_model_loading, config
+        )
+        return cls(dataset_config, launcher, model, orig_config)
+
+
+class BaseOpenVinoClipModel(BaseCascadeModel):
     def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False, config=None):
         super().__init__(network_info, launcher, delayed_model_loading)
         self.network_info = network_info
         self.launcher = launcher
         self.config = config or {}
-        parts = ['text_encoder', 'image_encoder']
-        network_info = self.fill_part_with_model(network_info, parts, models_args, False, delayed_model_loading)
-        if not contains_all(network_info, parts) and not delayed_model_loading:
-            raise ConfigError('configuration for text_encoder/image_encoder does not exist')
+        self.templates_file = None
+        self.parameters_file = None
+        self.templates = ["a photo of a {classname}"]
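+        # model parts come from the config keys rather than a hard-coded text/image encoder pair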
+        self.parts = network_info.keys()
+        if launcher:
+            network_info = self.fill_part_with_model(network_info, self.parts,
+                                                     models_args, False, delayed_model_loading)
+            if not contains_all(network_info, self.parts) and not delayed_model_loading:
+                raise ConfigError('configuration for text_encoder/image_encoder does not exist')
         if not delayed_model_loading:
             self.create_pipeline(launcher, network_info)
 
     def create_pipeline(self, launcher, network_info):
-        orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
-        self.load_models(network_info, launcher, True)
-
-        self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
-        self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
-
-        unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
-        if unet_shapes[0][0].is_dynamic:
-            self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
-        else:
-            self.templates_file = None
+        raise NotImplementedError("Subclasses should implement this method")
 
-        self.classnames_file = self.config.get("classnames", "classnames.json")
-        self.parameters_file = self.config.get("pretrained_model_params", None)
-        self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+    def get_logits(self, image_features, zeroshot_weights):
+        raise NotImplementedError("Subclasses should implement this method")
 
     def predict(self, identifiers, input_data, zeroshot_weights):
         preds = []
         for idx, image_data in zip(identifiers, input_data):
-            image = np.expand_dims(image_data, axis=0)
-            image_features = self.encode_image(image)
-            image_features = self.normalize(image_features, axis=-1)
-            logits = 100. * image_features @ zeroshot_weights
+            image_features = self.encode_image(image_data)
+            logits = self.get_logits(image_features, zeroshot_weights)
             preds.append(ClassificationPrediction(idx, np.squeeze(logits, axis=0)))
         return None, preds
 
@@ -116,23 +136,11 @@ def get_network(self):
             model_list.append({"name": model_part_name, "model": model})
         return model_list
 
-    def encode_image(self, image):
-        features = self.image_encoder(image)
-        return features[self.image_encoder.output()]
+    def encode_image(self, image_data):
+        raise NotImplementedError("Subclasses should implement this method")
 
     def encode_text(self, texts, params):
-        text = self.tokenizer(texts).to('cpu')
-        indices = text.detach().cpu().numpy()
-
-        x = params['token_embedding'][indices]
-        x = x + params['positional_embedding']
-        x = x.transpose(1, 0, 2)
-        x = self.text_encoder((x, params['attn_mask']))
-        x = x[self.text_encoder.output()]
-        x = x.transpose(1, 0, 2)
-        x = self.layer_norm(x, params['gamma'], params['beta'])
-        x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
-        return x
+        raise NotImplementedError("Subclasses should implement this method")
 
     @staticmethod
     def get_pretrained_model_params(path):
@@ -147,14 +155,20 @@ def get_pretrained_model_params(path):
         params['beta'] = open_clip_params['beta']
         return params
 
+    def get_class_embeddings(self, texts, params):
+        raise NotImplementedError("Subclasses should implement this method")
+
     def zero_shot_classifier(self, data_source):
         classnames = read_json(os.path.join(data_source, self.classnames_file))
         if self.templates_file:
             templates = read_json(os.path.join(data_source, self.templates_file))
         else:
-            templates = ["a photo of a {c}"]
+            templates = self.templates
+
+        params = None
+        if self.parameters_file:
+            params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
 
-        params = self.get_pretrained_model_params(os.path.join(data_source, self.parameters_file))
         print_info('Encoding zeroshot weights for {} imagenet classes'.format(len(classnames)))
 
         zeroshot_weights = []
@@ -163,12 +177,9 @@ def zero_shot_classifier(self, data_source):
             iterator = tqdm(classnames, mininterval=2)
 
         for classname in iterator:
-            texts = [template.format(c=classname) for template in templates]
-            class_embeddings = self.encode_text(texts, params)
-            class_embedding = self.normalize(class_embeddings, axis=-1)
-            class_embedding = np.mean(class_embedding, axis=0)
-            class_embedding /= np.linalg.norm(class_embedding, ord=2)
-            zeroshot_weights.append(class_embedding)
+            texts = [template.format(classname=classname) for template in templates]
+            class_embeddings = self.get_class_embeddings(texts, params)
+            zeroshot_weights.append(class_embeddings)
         return np.stack(zeroshot_weights, axis=1)
 
     def load_models(self, network_info, launcher, log=False):
@@ -192,7 +203,7 @@ def load_model(self, network_list, launcher):
             setattr(self, "{}_model".format(network_list["name"]), network)
 
     def print_input_output_info(self):
-        model_parts = ("text_encoder", "image_encoder")
+        model_parts = self.parts
         for part in model_parts:
             part_model_id = "{}_model".format(part)
             model = getattr(self, part_model_id, None)
@@ -218,3 +229,113 @@ def normalize(input_array, p=2, axis=-1, epsilon=1e-12):
         norm = np.maximum(norm, epsilon)
         normalized = input_array / norm
         return normalized
+
+
+class OpenVinoClipVitModel(BaseOpenVinoClipModel):
+    def create_pipeline(self, launcher, network_info):
+        orig_model_name = self.config.get("orig_model_name", "ViT-B-16-plus-240")
+        self.load_models(network_info, launcher, True)
+        self.text_encoder = launcher.ie_core.compile_model(self.text_encoder_model, launcher.device)
+        self.image_encoder = launcher.ie_core.compile_model(self.image_encoder_model, launcher.device)
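+        # a dynamic batch on the text encoder means a whole set of prompt templates can be encoded per class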
+        unet_shapes = [inp.get_partial_shape() for inp in self.text_encoder_model.inputs]
+        if unet_shapes[0][0].is_dynamic:
+            self.templates_file = self.config.get("templates", "zeroshot_classification_templates.json")
+
+        self.classnames_file = self.config.get("classnames", "classnames.json")
+        self.parameters_file = self.config.get("pretrained_model_params", None)
+        self.tokenizer = open_clip.get_tokenizer(orig_model_name)
+
+    def get_logits(self, image_features, zeroshot_weights):
+        image_features = self.normalize(image_features, axis=-1)
+        logits = 100. * image_features @ zeroshot_weights
+        return logits
+
+    def encode_image(self, image_data):
+        image = np.expand_dims(image_data, axis=0)
+        features = self.image_encoder(image)
+        return features[self.image_encoder.output()]
+
+    def encode_text(self, texts, params):
+        text = self.tokenizer(texts).to('cpu')
+        indices = text.detach().cpu().numpy()
+
+        x = params['token_embedding'][indices]
+        x = x + params['positional_embedding']
+        x = x.transpose(1, 0, 2)
+        x = self.text_encoder((x, params['attn_mask']))
+        x = x[self.text_encoder.output()]
+        x = x.transpose(1, 0, 2)
+        x = self.layer_norm(x, params['gamma'], params['beta'])
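+        # take the features at the EOT token (the highest token id in each sequence) and project them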
+        x = x[np.arange(x.shape[0]), np.argmax(indices, axis=-1)] @ params['text_projection']
+        return x
+
+    def get_class_embeddings(self, texts, params):
+        class_embeddings = self.encode_text(texts, params)
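+        # normalize the per-template embeddings, average them, then re-normalize the mean to unit length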
+        class_embedding = self.normalize(class_embeddings, axis=-1)
+        class_embedding = np.mean(class_embedding, axis=0)
+        class_embedding /= np.linalg.norm(class_embedding, ord=2)
+        return class_embedding
+
+
+class OpenVinoJinaClipModel(BaseOpenVinoClipModel):
+    def create_pipeline(self, launcher, network_info):
+        if isinstance(AutoTokenizer, UnsupportedPackage):
+            AutoTokenizer.raise_error(self.__class__.__name__)
+        if isinstance(AutoModel, UnsupportedPackage):
+            AutoModel.raise_error(self.__class__.__name__)
+        if isinstance(torch, UnsupportedPackage):
+            torch.raise_error(self.__class__.__name__)
+
+        orig_model_name = self.config.get("orig_model_name", "jinaai/jina-clip-v1")
+
+        model = AutoModel.from_pretrained(orig_model_name, trust_remote_code=True)
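+        # the HF model supplies the image preprocessor and, in reference mode, the text/vision encoders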
+        if launcher:
+            self.load_models(network_info, launcher, True)
+            self.text_encoder = launcher.ie_core.compile_model(self.text_model, launcher.device)
+            self.vision_encoder = launcher.ie_core.compile_model(self.vision_model, launcher.device)
+        else:
+            self.text_encoder = model.text_model
+            self.vision_encoder = model.vision_model
+
+        self.templates = ["{classname}"]
+        self.classnames_file = self.config.get("classnames", "classnames.json")
+        self.tokenizer = AutoTokenizer.from_pretrained(orig_model_name, trust_remote_code=True)
+        self.processor = model.get_preprocess()
+
+    def encode_image(self, image_data):
+        image = Image.fromarray(image_data)
+        vision_input = self.processor(images=[image], return_tensors="pt")
+        image_embeddings = self.vision_encoder(vision_input["pixel_values"])
+
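+        # the PyTorch reference path returns a Tensor; the compiled OpenVINO model returns an indexable result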
+        if isinstance(image_embeddings, torch.Tensor):
+            image_embeddings = image_embeddings.detach().numpy()
+        else:
+            image_embeddings = image_embeddings[0]
+
+        return image_embeddings
+
+    def encode_text(self, text_input):
+        text_embeddings = self.text_encoder(text_input["input_ids"])
+
+        if isinstance(text_embeddings, torch.Tensor):
+            text_embeddings = text_embeddings.detach().numpy()
+        else:
+            text_embeddings = text_embeddings[0]
+        return text_embeddings
+
+    def get_logits(self, image_features, zeroshot_weights):
+        text_embeddings = np.squeeze(zeroshot_weights)
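+        # pairwise dot products between each image embedding and every class text embedding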
+        similarity = []
+        for emb1 in image_features:
+            temp_similarity = []
+            for emb2 in text_embeddings:
+                temp_similarity.append(emb1 @ emb2)
+            similarity.append(temp_similarity)
+
+        logits = 100. * softmax(similarity, axis=-1)
+        return logits
+
+    def get_class_embeddings(self, texts, params):
+        text_input = self.tokenizer(texts, return_tensors="pt", padding="max_length",
+                                    max_length=512, truncation=True).to("cpu")
+        return self.encode_text(text_input)