Skip to content

Commit f8e5764

Browse files
committed
[NOID] various fixes - added pinecone handler
1 parent 08378f3 commit f8e5764

File tree

6 files changed

+640
-175
lines changed

6 files changed

+640
-175
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
package apoc.vectordb;
2+
3+
import static apoc.ml.RestAPIConfig.METHOD_KEY;
4+
import static apoc.vectordb.VectorDb.executeRequest;
5+
import static apoc.vectordb.VectorDb.getEmbeddingResultStream;
6+
import static apoc.vectordb.VectorDbHandler.Type.PINECONE;
7+
import static apoc.vectordb.VectorDbUtil.getCommonVectorDbInfo;
8+
9+
import apoc.Extended;
10+
import apoc.ml.RestAPIConfig;
11+
import apoc.result.MapResult;
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.stream.Collectors;
16+
import java.util.stream.Stream;
17+
import org.neo4j.graphdb.GraphDatabaseService;
18+
import org.neo4j.graphdb.Transaction;
19+
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
20+
import org.neo4j.procedure.Context;
21+
import org.neo4j.procedure.Description;
22+
import org.neo4j.procedure.Mode;
23+
import org.neo4j.procedure.Name;
24+
import org.neo4j.procedure.Procedure;
25+
26+
@Extended
27+
public class Pinecone {
28+
public static final VectorDbHandler DB_HANDLER = PINECONE.get();
29+
30+
@Context
31+
public ProcedureCallContext procedureCallContext;
32+
33+
@Context
34+
public Transaction tx;
35+
36+
@Context
37+
public GraphDatabaseService db;
38+
39+
@Procedure("apoc.vectordb.pinecone.info")
40+
@Description(
41+
"apoc.vectordb.pinecone.info(hostOrKey, index, $configuration) - Get information about the specified existing index or throws an error if it does not exist")
42+
public Stream<MapResult> getInfo(
43+
@Name("hostOrKey") String hostOrKey,
44+
@Name("index") String index,
45+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
46+
throws Exception {
47+
String url = "%s/indexes/%s";
48+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
49+
50+
RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), Map.of());
51+
return executeRequest(restAPIConfig).map(v -> (Map<String, Object>) v).map(MapResult::new);
52+
}
53+
54+
@Procedure("apoc.vectordb.pinecone.createCollection")
55+
@Description(
56+
"apoc.vectordb.pinecone.createCollection(hostOrKey, index, similarity, size, $configuration) - Creates a index, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`")
57+
public Stream<MapResult> createCollection(
58+
@Name("hostOrKey") String hostOrKey,
59+
@Name("index") String index,
60+
@Name("similarity") String similarity,
61+
@Name("size") Long size,
62+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
63+
throws Exception {
64+
String url = "%s/indexes";
65+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
66+
config.putIfAbsent(METHOD_KEY, "POST");
67+
68+
Map<String, Object> additionalBodies = Map.of(
69+
"name", index,
70+
"dimension", size,
71+
"metric", similarity);
72+
RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);
73+
return executeRequest(restAPIConfig).map(v -> (Map<String, Object>) v).map(MapResult::new);
74+
}
75+
76+
@Procedure("apoc.vectordb.pinecone.deleteCollection")
77+
@Description(
78+
"apoc.vectordb.pinecone.deleteCollection(hostOrKey, index, $configuration) - Deletes a index with the name specified in the 2nd parameter")
79+
public Stream<MapResult> deleteCollection(
80+
@Name("hostOrKey") String hostOrKey,
81+
@Name("index") String index,
82+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
83+
throws Exception {
84+
85+
String url = "%s/indexes/%s";
86+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
87+
config.putIfAbsent(METHOD_KEY, "DELETE");
88+
89+
RestAPIConfig restAPIConfig = new RestAPIConfig(config);
90+
return executeRequest(restAPIConfig).map(v -> (Map<String, Object>) v).map(MapResult::new);
91+
}
92+
93+
@Procedure("apoc.vectordb.pinecone.upsert")
94+
@Description(
95+
"apoc.vectordb.pinecone.upsert(hostOrKey, index, vectors, $configuration) - Upserts, in the index with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '<vectorDb>', medatada: '<metadata>'}]")
96+
public Stream<MapResult> upsert(
97+
@Name("hostOrKey") String hostOrKey,
98+
@Name("index") String index,
99+
@Name("vectors") List<Map<String, Object>> vectors,
100+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
101+
throws Exception {
102+
103+
String url = "%s/vectors/upsert";
104+
105+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
106+
config.putIfAbsent(METHOD_KEY, "POST");
107+
108+
vectors = vectors.stream()
109+
.map(i -> {
110+
Map<String, Object> map = new HashMap<>(i);
111+
map.putIfAbsent("values", map.remove("vector"));
112+
return map;
113+
})
114+
.collect(Collectors.toList());
115+
116+
Map<String, Object> additionalBodies = Map.of("vectors", vectors);
117+
RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);
118+
return executeRequest(restAPIConfig).map(v -> (Map<String, Object>) v).map(MapResult::new);
119+
}
120+
121+
@Procedure("apoc.vectordb.pinecone.delete")
122+
@Description(
123+
"apoc.vectordb.pinecone.delete(hostOrKey, index, ids, $configuration) - Delete the vectors with the specified `ids`")
124+
public Stream<MapResult> delete(
125+
@Name("hostOrKey") String hostOrKey,
126+
@Name("index") String index,
127+
@Name("vectors") List<Object> ids,
128+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
129+
throws Exception {
130+
131+
String url = "%s/vectors/delete";
132+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
133+
config.putIfAbsent(METHOD_KEY, "POST");
134+
135+
Map<String, Object> additionalBodies = Map.of("ids", ids);
136+
RestAPIConfig apiConfig = new RestAPIConfig(config, Map.of(), additionalBodies);
137+
return executeRequest(apiConfig).map(v -> (Map<String, Object>) v).map(MapResult::new);
138+
}
139+
140+
@Procedure(value = "apoc.vectordb.pinecone.get")
141+
@Description(
142+
"apoc.vectordb.pinecone.get(hostOrKey, index, ids, $configuration) - Get the vectors with the specified `ids`")
143+
public Stream<VectorDbUtil.EmbeddingResult> get(
144+
@Name("hostOrKey") String hostOrKey,
145+
@Name("index") String index,
146+
@Name("ids") List<Object> ids,
147+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
148+
throws Exception {
149+
return getCommon(hostOrKey, index, ids, configuration);
150+
}
151+
152+
@Procedure(value = "apoc.vectordb.pinecone.getAndUpdate", mode = Mode.WRITE)
153+
@Description(
154+
"apoc.vectordb.pinecone.getAndUpdate(hostOrKey, index, ids, $configuration) - Get the vectors with the specified `ids`")
155+
public Stream<VectorDbUtil.EmbeddingResult> getAndUpdate(
156+
@Name("hostOrKey") String hostOrKey,
157+
@Name("index") String index,
158+
@Name("ids") List<Object> ids,
159+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
160+
throws Exception {
161+
return getCommon(hostOrKey, index, ids, configuration);
162+
}
163+
164+
private Stream<VectorDbUtil.EmbeddingResult> getCommon(
165+
String hostOrKey, String index, List<Object> ids, Map<String, Object> configuration) throws Exception {
166+
String url = "%s/vectors/fetch";
167+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
168+
169+
VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids);
170+
171+
return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> {
172+
Object vectors = ((Map) v).get("vectors");
173+
return ((Map) vectors).values().stream();
174+
});
175+
}
176+
177+
@Procedure(value = "apoc.vectordb.pinecone.query")
178+
@Description(
179+
"apoc.vectordb.pinecone.query(hostOrKey, index, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the index with the name specified in the 2nd parameter")
180+
public Stream<VectorDbUtil.EmbeddingResult> query(
181+
@Name("hostOrKey") String hostOrKey,
182+
@Name("index") String index,
183+
@Name(value = "vector", defaultValue = "[]") List<Double> vector,
184+
@Name(value = "filter", defaultValue = "{}") Map<String, Object> filter,
185+
@Name(value = "limit", defaultValue = "10") long limit,
186+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
187+
throws Exception {
188+
return queryCommon(hostOrKey, index, vector, filter, limit, configuration);
189+
}
190+
191+
@Procedure(value = "apoc.vectordb.pinecone.queryAndUpdate", mode = Mode.WRITE)
192+
@Description(
193+
"apoc.vectordb.pinecone.queryAndUpdate(hostOrKey, index, vector, filter, limit, $configuration) - Retrieve closest vectors the the defined `vector`, `limit` of results, in the index with the name specified in the 2nd parameter")
194+
public Stream<VectorDbUtil.EmbeddingResult> queryAndUpdate(
195+
@Name("hostOrKey") String hostOrKey,
196+
@Name("index") String index,
197+
@Name(value = "vector", defaultValue = "[]") List<Double> vector,
198+
@Name(value = "filter", defaultValue = "{}") Map<String, Object> filter,
199+
@Name(value = "limit", defaultValue = "10") long limit,
200+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
201+
throws Exception {
202+
return queryCommon(hostOrKey, index, vector, filter, limit, configuration);
203+
}
204+
205+
private Stream<VectorDbUtil.EmbeddingResult> queryCommon(
206+
String hostOrKey,
207+
String index,
208+
List<Double> vector,
209+
Map<String, Object> filter,
210+
long limit,
211+
Map<String, Object> configuration)
212+
throws Exception {
213+
String url = "%s/query";
214+
Map<String, Object> config = getVectorDbInfo(hostOrKey, index, configuration, url);
215+
216+
VectorEmbeddingConfig conf =
217+
DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, index);
218+
219+
return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> {
220+
Map map = (Map) v;
221+
return ((List) map.get("matches")).stream();
222+
});
223+
}
224+
225+
private Map<String, Object> getVectorDbInfo(
226+
String hostOrKey, String index, Map<String, Object> configuration, String templateUrl) {
227+
return getCommonVectorDbInfo(hostOrKey, index, configuration, templateUrl, DB_HANDLER);
228+
}
229+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package apoc.vectordb;
2+
3+
import static apoc.ml.RestAPIConfig.BODY_KEY;
4+
import static apoc.ml.RestAPIConfig.ENDPOINT_KEY;
5+
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
6+
import static apoc.ml.RestAPIConfig.METHOD_KEY;
7+
import static apoc.util.MapUtil.map;
8+
import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY;
9+
10+
import apoc.ml.RestAPIConfig;
11+
import java.net.URL;
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.stream.Collectors;
16+
import org.apache.commons.lang3.StringUtils;
17+
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
18+
19+
public class PineconeHandler implements VectorDbHandler {
20+
21+
@Override
22+
public String getUrl(String hostOrKey) {
23+
return StringUtils.isBlank(hostOrKey) ? "https://api.pinecone.io" : hostOrKey;
24+
}
25+
26+
@Override
27+
public VectorEmbeddingHandler getEmbedding() {
28+
return new PineconeEmbeddingHandler();
29+
}
30+
31+
@Override
32+
public String getLabel() {
33+
return "Pinecone";
34+
}
35+
36+
@Override
37+
public Map<String, Object> getCredentials(Object credentialsObj, Map<String, Object> config) {
38+
Map headers = (Map) config.getOrDefault(HEADERS_KEY, new HashMap<>());
39+
headers.putIfAbsent("Api-Key", credentialsObj);
40+
config.put(HEADERS_KEY, headers);
41+
return config;
42+
}
43+
44+
// -- embedding handler
45+
static class PineconeEmbeddingHandler implements VectorEmbeddingHandler {
46+
47+
/**
48+
* "method" should be "GET", but is null as a workaround.
49+
* Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode`
50+
* that makes the request to respond 200 OK, but returns an empty result
51+
*/
52+
@Override
53+
public <T> VectorEmbeddingConfig fromGet(
54+
Map<String, Object> config, ProcedureCallContext procedureCallContext, List<T> ids) {
55+
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());
56+
57+
config.put(BODY_KEY, null);
58+
59+
String endpoint = (String) config.get(ENDPOINT_KEY);
60+
if (!endpoint.contains("ids=")) {
61+
String idsQueryUrl = ids.stream().map(i -> "ids=" + i).collect(Collectors.joining("&"));
62+
63+
if (endpoint.contains("?")) {
64+
endpoint += "&" + idsQueryUrl;
65+
} else {
66+
endpoint += "?" + idsQueryUrl;
67+
}
68+
}
69+
70+
config.put(ENDPOINT_KEY, endpoint);
71+
return getVectorEmbeddingConfig(config, fields, map());
72+
}
73+
74+
@Override
75+
public VectorEmbeddingConfig fromQuery(
76+
Map<String, Object> config,
77+
ProcedureCallContext procedureCallContext,
78+
List<Double> vector,
79+
Object filter,
80+
long limit,
81+
String index) {
82+
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());
83+
84+
Map<String, Object> additionalBodies = map("vector", vector, "filter", filter, "topK", limit);
85+
86+
return getVectorEmbeddingConfig(config, fields, additionalBodies);
87+
}
88+
89+
private VectorEmbeddingConfig getVectorEmbeddingConfig(
90+
Map<String, Object> config, List<String> fields, Map<String, Object> additionalBodies) {
91+
config.putIfAbsent(VECTOR_KEY, "values");
92+
93+
VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config);
94+
95+
additionalBodies.put("includeMetadata", fields.contains("metadata"));
96+
additionalBodies.put("includeValues", fields.contains("vector") && conf.isAllResults());
97+
98+
RestAPIConfig apiConfig = conf.getApiConfig();
99+
Map<String, Object> headers = apiConfig.getHeaders();
100+
headers.remove(METHOD_KEY);
101+
apiConfig.setHeaders(headers);
102+
103+
return VectorEmbeddingHandler.populateApiBodyRequest(conf, additionalBodies);
104+
}
105+
}
106+
}

full/src/main/java/apoc/vectordb/VectorDbHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ default Map<String, Object> getCredentials(Object credentialsObj, Map<String, Ob
2222
enum Type {
2323
CHROMA(new ChromaHandler()),
2424
QDRANT(new QdrantHandler()),
25+
PINECONE(new PineconeHandler()),
2526
WEAVIATE(new WeaviateHandler());
2627

2728
private final VectorDbHandler handler;

full/src/test/java/apoc/util/ExtendedTestUtil.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package apoc.util;
22

3+
import static apoc.util.TestUtil.testCall;
34
import static apoc.util.TestUtil.testCallAssertions;
5+
import static org.junit.Assert.assertTrue;
6+
import static org.junit.Assert.fail;
47
import static org.neo4j.test.assertion.Assert.assertEventually;
58

69
import java.util.Collections;
@@ -67,4 +70,14 @@ public static void testResultEventually(
6770
timeout,
6871
TimeUnit.SECONDS);
6972
}
73+
74+
public static void assertFails(
75+
GraphDatabaseService db, String query, Map<String, Object> params, String expectedErrMsg) {
76+
try {
77+
testCall(db, query, params, r -> fail("Should fail due to " + expectedErrMsg));
78+
} catch (Exception e) {
79+
String actualErrMsg = e.getMessage();
80+
assertTrue("Actual err. message is: " + actualErrMsg, actualErrMsg.contains(expectedErrMsg));
81+
}
82+
}
7083
}

0 commit comments

Comments
 (0)