Skip to content

Commit 48c41c4

Browse files
committed
Use torch version of softmax
1 parent ccadf54 commit 48c41c4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/openvino_clip_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import os
1717
import numpy as np
1818
from PIL import Image
19-
from scipy.special import softmax
20-
2119
from .base_custom_evaluator import BaseCustomEvaluator
2220
from .base_models import BaseCascadeModel
2321
from ...config import ConfigError
@@ -43,6 +41,7 @@
4341

4442
try:
4543
import torch
44+
import torch.nn.functional as F
4645
except ImportError as torch_error:
4746
torch = UnsupportedPackage("torch", torch_error.msg)
4847

@@ -332,7 +331,8 @@ def get_logits(self, image_features, zeroshot_weights):
332331
temp_simularity.append(emb1 @ emb2)
333332
simularity.append(temp_simularity)
334333

335-
logits = 100. * softmax(simularity)
334+
simularity_tensor = torch.tensor(simularity)
335+
logits = 100. * F.softmax(simularity_tensor, dim=-1).numpy()
336336
return logits
337337

338338
def get_class_embeddings(self, texts, params):

0 commit comments

Comments
 (0)