Skip to content

Commit 13d069d

Browse files
authored
Add mean_pooling parameter to reidentification.py
Average the embeddings of all tokens for last_hidden_state
1 parent 3b53c42 commit 13d069d

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tools/accuracy_checker/accuracy_checker/adapters/reidentification.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright (c) 2018-2024 Intel Corporation
2+
Copyright (c) 2018-2025 Intel Corporation
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -41,7 +41,9 @@ def parameters(cls):
4141
choices=['sum', 'concatenation']
4242
),
4343
'target_out': StringField(optional=True, description='Target output layer name'),
44-
'keep_shape': BoolField(optional=True, default=False, description='keep output embedding shape')
44+
'keep_shape': BoolField(optional=True, default=False, description='keep output embedding shape'),
45+
'mean_pooling': BoolField(optional=True, default=False,
46+
description='Average the embeddings of all tokens for last_hidden_state')
4547
})
4648

4749
return parameters
@@ -54,6 +56,7 @@ def configure(self):
5456
self.joining_method = self.get_value_from_config('joining_method')
5557
self.target_out = self.get_value_from_config('target_out')
5658
self.keep_shape = self.get_value_from_config('keep_shape')
59+
self.mean_pooling = self.get_value_from_config('mean_pooling')
5760

5861
def process(self, raw, identifiers, frame_meta):
5962
"""
@@ -67,6 +70,10 @@ def process(self, raw, identifiers, frame_meta):
6770
raw_prediction = self._extract_predictions(raw, frame_meta)
6871
prediction = raw_prediction[self.output_blob]
6972

73+
if self.mean_pooling:
74+
# Shape: (1, 128, 768) -> (1, 768)
75+
prediction = np.mean(prediction, axis=1)
76+
7077
if self.grn_workaround:
7178
# workaround: GRN layer
7279
prediction = self._grn_layer(prediction)

0 commit comments

Comments
 (0)