Skip to content

Commit 9424ed8

Browse files
authored
Fixes #4145: Add support for huggingface models (#4192)
1 parent 9c9e015 commit 9424ed8

File tree

3 files changed

+116
-5
lines changed

3 files changed

+116
-5
lines changed

docs/asciidoc/modules/ROOT/pages/ml/openai.adoc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,25 @@ Or via https://localai.io/[LocalAI APIs] (note that the apiKey is `null` by defa
269269
CALL apoc.ml.openai.embedding(['Some Text'], "ignored",
270270
{endpoint: 'http://localhost:8080/v1', model: 'text-embedding-ada-002'})
271271
----
272+
We can use https://huggingface.co/tomasonjo[tomasonjo models] to generate Cypher from text:
273+
274+
[source,cypher]
275+
----
276+
WITH 'Node properties are the following:
277+
Movie {title: STRING, votes: INTEGER, tagline: STRING, released: INTEGER}, Person {born: INTEGER, name: STRING}
278+
Relationship properties are the following:
279+
ACTED_IN {roles: LIST}, REVIEWED {summary: STRING, rating: INTEGER}
280+
The relationships are the following:
281+
(:Person)-[:ACTED_IN]->(:Movie), (:Person)-[:DIRECTED]->(:Movie), (:Person)-[:PRODUCED]->(:Movie), (:Person)-[:WROTE]->(:Movie), (:Person)-[:FOLLOWS]->(:Person), (:Person)-[:REVIEWED]->(:Movie)'
282+
as schema,
283+
'Which actors played in the most movies?' as question
284+
CALL apoc.ml.openai.chat([
285+
{role:"system", content:"Given an input question, convert it to a Cypher query. No pre-amble."},
286+
{role:"user", content:"Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
287+
\n "+ schema +" \n\n Question: "+ question +" \n Cypher query:"}
288+
], '<apiKey>', { endpoint: 'http://localhost:8080/chat/completions', model: 'text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf'})
289+
YIELD value RETURN value
290+
----
272291

273292
Or also, by using https://github.com/fardjad/node-llmatic[LLMatic Library]:
274293
[source,cypher]

extended/src/test/java/apoc/ml/OpenAILocalAIIT.java

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,67 @@
11
package apoc.ml;
22

3+
import apoc.coll.Coll;
4+
import apoc.meta.Meta;
5+
import apoc.text.Strings;
36
import apoc.util.TestUtil;
47
import apoc.util.Util;
58
import org.junit.Assume;
69
import org.junit.Before;
710
import org.junit.Rule;
811
import org.junit.Test;
12+
import org.neo4j.graphdb.Transaction;
913
import org.neo4j.test.rule.DbmsRule;
1014
import org.neo4j.test.rule.ImpermanentDbmsRule;
1115

1216
import java.util.List;
1317
import java.util.Map;
1418

1519
import static apoc.ml.MLUtil.*;
16-
import static apoc.ml.OpenAITestResultUtils.*;
20+
import static apoc.ml.OpenAITestResultUtils.CHAT_COMPLETION_QUERY;
21+
import static apoc.ml.OpenAITestResultUtils.COMPLETION_QUERY;
22+
import static apoc.ml.OpenAITestResultUtils.EMBEDDING_QUERY;
23+
import static apoc.ml.OpenAITestResultUtils.TEXT_TO_CYPHER_QUERY;
24+
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
25+
import static apoc.ml.OpenAITestResultUtils.assertCompletion;
1726
import static apoc.util.TestUtil.testCall;
1827
import static org.junit.jupiter.api.Assertions.assertEquals;
28+
import static org.junit.jupiter.api.Assertions.assertNotNull;
29+
import static org.junit.jupiter.api.Assertions.assertTrue;
1930

2031

2132
/**
2233
* To start the tests, follow the instructions provided here: https://localai.io/basics/build/
2334
* Then, download the embedding model, as explained here: https://localai.io/models/#embeddings-bert
2435
* Finally, set the env var `LOCAL_AI_URL=http://localhost:<portNumber>/v1`, default is `LOCAL_AI_URL=http://localhost:8080/v1`
36+
*
37+
* To test chatCompletionTomasonjo/text2CypherTomasonjo4Bit run localai with the command below:
38+
* ./local-ai run https://huggingface.co/tomasonjo/text2cypher-demo-4bit-gguf/resolve/main/text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf
2539
*/
2640
public class OpenAILocalAIIT {
2741

2842
private String localAIUrl;
43+
private static final String OPENAI_KEY = System.getenv("OPENAI_KEY");
2944

3045
@Rule
3146
public DbmsRule db = new ImpermanentDbmsRule();
3247

33-
3448
@Before
3549
public void setUp() throws Exception {
3650
localAIUrl = System.getenv("LOCAL_AI_URL");
3751
Assume.assumeNotNull("No LOCAL_AI_URL environment configured", localAIUrl);
38-
TestUtil.registerProcedure(db, OpenAI.class);
52+
TestUtil.registerProcedure(db, OpenAI.class, Prompt.class, Meta.class, Strings.class, Coll.class);
53+
54+
String movies = Util.readResourceFile("movies.cypher");
55+
try (Transaction tx = db.beginTx()) {
56+
tx.execute(movies);
57+
tx.commit();
58+
}
59+
60+
String rag = Util.readResourceFile("rag.cypher");
61+
try (Transaction tx = db.beginTx()) {
62+
tx.execute(rag);
63+
tx.commit();
64+
}
3965
}
4066

4167
@Test
@@ -50,6 +76,62 @@ public void getEmbedding() {
5076
});
5177
}
5278

79+
@Test
80+
public void chatCompletionTomasonjo() {
81+
/*
82+
Useful terminal commands:
83+
# Run models
84+
./local-ai run https://huggingface.co/tomasonjo/text2cypher-demo-4bit-gguf/resolve/main/text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf // List models
85+
./local-ai run https://huggingface.co/tomasonjo/text2cypher-codestral-q4_k_m-gguf/resolve/main/text2cypher-codestral-q4_k_m-gguf-unsloth.Q4_K_M.gguf
86+
# List Models
87+
curl http://localhost:8080/v1/models
88+
# Call model
89+
curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{"model":"text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf", "messages": [{"role": "user", "content": "What is the color of the sky? Answer in one word"}] }'
90+
*/
91+
String[] models = {
92+
"text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf",
93+
"text2cypher-demo-8bit-gguf-unsloth.Q8_0.gguf",
94+
};
95+
96+
for (String model : models) {
97+
testCall(db, CHAT_COMPLETION_QUERY,
98+
getParams(model),
99+
row -> assertChatCompletion(row, model));
100+
}
101+
}
102+
103+
@Test
104+
public void text2CypherTomasonjo4Bit() {
105+
assertNotNull(OPENAI_KEY);
106+
String schema = """
107+
Node properties are the following:
108+
Movie {title: STRING, votes: INTEGER, tagline: STRING, released: INTEGER}, Person {born: INTEGER, name: STRING}
109+
Relationship properties are the following:
110+
ACTED_IN {roles: LIST}, REVIEWED {summary: STRING, rating: INTEGER}
111+
The relationships are the following:
112+
(:Person)-[:ACTED_IN]->(:Movie), (:Person)-[:DIRECTED]->(:Movie), (:Person)-[:PRODUCED]->(:Movie), (:Person)-[:WROTE]->(:Movie), (:Person)-[:FOLLOWS]->(:Person), (:Person)-[:REVIEWED]->(:Movie)
113+
""";
114+
115+
String question = "Which actors played in the most movies?";
116+
String model = "text2cypher-demo-4bit-gguf-unsloth.Q4_K_M.gguf";
117+
118+
Map<String, Object> params = Util.map(
119+
"schema", schema,
120+
"question", question
121+
);
122+
123+
params.putAll(getParams(model));
124+
125+
testCall(db, TEXT_TO_CYPHER_QUERY,
126+
params,
127+
row -> {
128+
String cypherResult = assertChatCompletion(row, model);
129+
// Check that is valid query
130+
long count = TestUtil.count(db, cypherResult);
131+
assertTrue(count > 0);
132+
});
133+
}
134+
53135
@Test
54136
public void completion() {
55137
testCall(db, COMPLETION_QUERY,

extended/src/test/java/apoc/ml/OpenAITestResultUtils.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@ public class OpenAITestResultUtils {
2323
public static final String COMPLETION_QUERY = "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)";
2424

2525
public static final String COMPLETION_QUERY_EXTENDED_PROMPT = "CALL apoc.ml.openai.completion('\\n\\nHuman: What color is sky?\\n\\nAssistant:', $apiKey, $conf)";
26-
26+
public static final String TEXT_TO_CYPHER_QUERY = """
27+
WITH $schema as schema, $question as question
28+
CALL apoc.ml.openai.chat([
29+
{role:"system", content:"Given an input question, convert it to a Cypher query. No pre-amble."},
30+
{role:"user", content:"Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
31+
\\n "+ schema +" \\n\\n Question: "+ question +" \\n Cypher query:"}
32+
], $apiKey, $conf) YIELD value RETURN value
33+
""";
34+
2735
public static void assertEmbeddings(Map<String, Object> row) {
2836
assertEmbeddings(row, 1536);
2937
}
@@ -51,7 +59,7 @@ public static void assertCompletion(Map<String, Object> row, String expectedMode
5159
assertEquals("text_completion", result.get("object"));
5260
}
5361

54-
public static void assertChatCompletion(Map<String, Object> row, String modelId) {
62+
public static String assertChatCompletion(Map<String, Object> row, String modelId) {
5563
var result = (Map<String,Object>) row.get("value");
5664
assertTrue(result.get("created") instanceof Number);
5765
assertTrue(result.containsKey("choices"));
@@ -65,5 +73,7 @@ public static void assertChatCompletion(Map<String, Object> row, String modelId)
6573
assertTrue(((Map) result.get("usage")).get("prompt_tokens") instanceof Number);
6674
assertEquals("chat.completion", result.get("object"));
6775
assertTrue(result.get("model").toString().startsWith(modelId));
76+
77+
return text;
6878
}
6979
}

0 commit comments

Comments
 (0)