Skip to content

Commit 9c9e015

Browse files
authored
Fixes #4182: The huggingface examples return strange results (#4190)
* Fixes #4182: The huggingface examples return strange results * changes review and removed unused imports
1 parent 487d970 commit 9c9e015

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

docs/asciidoc/modules/ROOT/pages/ml/openai.adoc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,12 @@ For the https://huggingface.co/[HuggingFace API], we have to define the config `
161161
For example:
162162
[source,cypher]
163163
----
164-
CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $huggingFaceApiKey,
165-
{endpoint: 'https://api-inference.huggingface.co/models/gpt2', apiType: 'HUGGINGFACE', model: 'gpt2', path: ''})
164+
CALL apoc.ml.openai.completion('[MASK] is the color of the sky', $huggingFaceApiKey,
165+
{endpoint: 'https://api-inference.huggingface.co/models/google-bert/bert-base-uncased', apiType: 'HUGGINGFACE'})
166166
----
167167

168+
With `gpt2` or other text-completion models, the returned answers are not valid.
169+
168170
Or also, by using the https://docs.cohere.com/docs[Cohere API], where we have to define `path: ''` so that the `/completions` suffix is not added to the URL:
169171
[source,cypher]
170172
----

extended/src/main/java/apoc/ml/OpenAI.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE;
2626
import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY;
2727
import static apoc.ml.MLUtil.*;
28+
import static apoc.ml.RestAPIConfig.METHOD_KEY;
2829

2930

3031
@Extended
@@ -103,6 +104,8 @@ private static void handleAPIProvider(OpenAIRequestHandler.Type type,
103104
}
104105
case HUGGINGFACE -> {
105106
configForPayload.putIfAbsent("inputs", inputs);
107+
configuration.putIfAbsent(PATH_CONF_KEY, "");
108+
headers.putIfAbsent(METHOD_KEY, "POST");
106109
configuration.putIfAbsent(JSON_PATH_CONF_KEY, "$[0]");
107110
}
108111
case ANTHROPIC -> {

extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,17 @@ public void setUp() throws Exception {
4444
public void completionWithHuggingFace() {
4545
String huggingFaceApiKey = System.getenv("HF_API_TOKEN");
4646
Assume.assumeNotNull("No HF_API_TOKEN environment configured", huggingFaceApiKey);
47-
48-
String modelId = "gpt2";
47+
48+
String modelId = "google-bert/bert-base-uncased";
4949
Map<String, String> conf = Map.of(ENDPOINT_CONF_KEY, "https://api-inference.huggingface.co/models/" + modelId,
50-
API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.HUGGINGFACE.name(),
51-
PATH_CONF_KEY, "",
52-
MODEL_CONF_KEY, modelId
50+
API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.HUGGINGFACE.name()
5351
);
54-
testCall(db, COMPLETION_QUERY,
52+
53+
testCall(db, "CALL apoc.ml.openai.completion('[MASK] is the color of the sky', $apiKey, $conf)",
5554
Map.of("conf", conf, "apiKey", huggingFaceApiKey),
5655
(row) -> {
5756
var result = (Map<String,Object>) row.get("value");
58-
String generatedText = (String) result.get("generated_text");
57+
String generatedText = (String) result.get("sequence");
5958
assertTrue(generatedText.toLowerCase().contains("blue"),
6059
"Actual generatedText is " + generatedText);
6160
});

0 commit comments

Comments
 (0)