11"""
2- Copyright (c) 2018-2024 Intel Corporation
2+ Copyright (c) 2018-2025 Intel Corporation
33
44Licensed under the Apache License, Version 2.0 (the "License");
55you may not use this file except in compliance with the License.
@@ -41,7 +41,9 @@ def parameters(cls):
4141 choices = ['sum' , 'concatenation' ]
4242 ),
4343 'target_out' : StringField (optional = True , description = 'Target output layer name' ),
44- 'keep_shape' : BoolField (optional = True , default = False , description = 'keep output embedding shape' )
44+ 'keep_shape' : BoolField (optional = True , default = False , description = 'keep output embedding shape' ),
45+ 'mean_pooling' : BoolField (optional = True , default = False ,
46+ description = 'Average the embeddings of all tokens for last_hidden_state' )
4547 })
4648
4749 return parameters
@@ -54,6 +56,7 @@ def configure(self):
5456 self .joining_method = self .get_value_from_config ('joining_method' )
5557 self .target_out = self .get_value_from_config ('target_out' )
5658 self .keep_shape = self .get_value_from_config ('keep_shape' )
59+ self .mean_pooling = self .get_value_from_config ('mean_pooling' )
5760
5861 def process (self , raw , identifiers , frame_meta ):
5962 """
@@ -67,6 +70,10 @@ def process(self, raw, identifiers, frame_meta):
6770 raw_prediction = self ._extract_predictions (raw , frame_meta )
6871 prediction = raw_prediction [self .output_blob ]
6972
73+ if self .mean_pooling :
74+ # Shape: (1, 128, 768) -> (1, 768)
75+ prediction = np .mean (prediction , axis = 1 )
76+
7077 if self .grn_workaround :
7178 # workaround: GRN layer
7279 prediction = self ._grn_layer (prediction )
0 commit comments