From 021411f4dfdb589c4db9c3e6de696536d07e1d8d Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Tue, 19 Nov 2024 02:28:30 -0800 Subject: [PATCH] Add hybridSearchAsync to AzureAISearchVectorStoreRecordCollection --- ...reAISearchVectorStoreRecordCollection.java | 129 ++++++++++++++---- 1 file changed, 106 insertions(+), 23 deletions(-) diff --git a/data/semantickernel-data-azureaisearch/src/main/java/com/microsoft/semantickernel/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java b/data/semantickernel-data-azureaisearch/src/main/java/com/microsoft/semantickernel/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java index c4f44c42..b1c15b82 100644 --- a/data/semantickernel-data-azureaisearch/src/main/java/com/microsoft/semantickernel/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java +++ b/data/semantickernel-data-azureaisearch/src/main/java/com/microsoft/semantickernel/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java @@ -10,6 +10,7 @@ import com.azure.search.documents.indexes.models.VectorSearchProfile; import com.azure.search.documents.models.IndexDocumentsResult; import com.azure.search.documents.models.IndexingResult; +import com.azure.search.documents.models.ScoringParameter; import com.azure.search.documents.models.SearchOptions; import com.azure.search.documents.models.VectorQuery; import com.azure.search.documents.models.VectorizableTextQuery; @@ -39,6 +40,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Vector; import java.util.stream.Collectors; import javax.annotation.Nonnull; import reactor.core.publisher.Flux; @@ -288,10 +290,8 @@ public Mono deleteBatchAsync(List keys, DeleteRecordOptions option }).collect(Collectors.toList())).then(); } - private Mono> searchAndMapAsync( - List vectorQueries, VectorSearchOptions options, - GetRecordOptions getRecordOptions) { - + private SearchOptions configureVectorSearchOptions( + List vectorQueries, VectorSearchOptions options) { String filter = AzureAISearchVectorStoreCollectionSearchMapping.getInstance() .getFilter(options.getVectorSearchFilter(), recordDefinition); @@ -299,7 +299,6 @@ private Mono> searchAndMapAsync( .setFilter(filter) .setTop(options.getTop()) .setSkip(options.getSkip()) - .setScoringParameters() .setVectorSearchOptions(new com.azure.search.documents.models.VectorSearchOptions() .setQueries(vectorQueries)); @@ -307,10 +306,16 @@ private Mono> searchAndMapAsync( searchOptions.setSelect(nonVectorFields.toArray(new String[0])); } + return searchOptions; + } + + private Mono> searchAndMapAsync(String query, + SearchOptions searchOptions, + boolean includeVectors) { VectorStoreRecordMapper mapper = this.options .getVectorStoreRecordMapper(); - return this.searchAsyncClient.search(null, searchOptions) + return this.searchAsyncClient.search(query, searchOptions) .flatMap(response -> { Record record; @@ -318,7 +323,7 @@ private Mono> searchAndMapAsync( if (mapper != null && mapper.getStorageModelToRecordMapper() != null) { record = mapper .mapStorageModelToRecord(response.getDocument(SearchDocument.class), - getRecordOptions); + new GetRecordOptions(includeVectors)); } else { record = response.getDocument(this.options.getRecordClass()); } @@ -329,7 +334,9 @@ record = response.getDocument(this.options.getRecordClass()); } /** - * Vectorizable text search. This method searches for records that are similar to the given text. + * Vectorizable text search. This method searches for records that are similar to the given text after vectorization. + *

+ * Vectorizer configuration must be set up in the Azure AI Search index. * * @param searchText The text to search with. * @param options The options to use for the search. @@ -353,8 +360,9 @@ public Mono> searchAsync(String searchText, : firstVectorFieldName).getEffectiveStorageName()) .setKNearestNeighborsCount(options.getTop())); - return searchAndMapAsync(vectorQueries, options, - new GetRecordOptions(options.isIncludeVectors())); + return searchAndMapAsync(null, + configureVectorSearchOptions(vectorQueries, options), + options.isIncludeVectors()); } /** @@ -367,22 +375,97 @@ public Mono> searchAsync(String searchText, @Override public Mono> searchAsync(List vector, VectorSearchOptions options) { - if (firstVectorFieldName == null) { - throw new SKException("No vector fields defined. Cannot perform vector search"); - } + return hybridSearchAsync(null, vector, options, null); + } - if (options == null) { - options = VectorSearchOptions.createDefault(firstVectorFieldName); + /** + * Hybrid search. This method searches for records that are similar to the given text and vector. + * + * @param searchText The text to search with. + * If null, only vector search is performed. + * @param vector The vector to search with. + * If null, only full text search is performed. + * @param options The vector search options used for the search. + * @param additionalSearchOptions AzureAI search additional options. + * If Filter, Top, Skip, Select or VectorSearchOptions are not null, they will be used instead of the default options. + *

+ * If null, default search options are used. + */ + public Mono> hybridSearchAsync(String searchText, + List vector, VectorSearchOptions options, SearchOptions additionalSearchOptions) { + SearchOptions searchOptions = new SearchOptions(); + + if (vector != null) { + if (firstVectorFieldName == null) { + throw new SKException("No vector fields defined. Cannot perform vector search"); + } + + if (options == null) { + options = VectorSearchOptions.createDefault(firstVectorFieldName); + } + + List vectorQueries = new ArrayList<>(); + vectorQueries.add(new VectorizedQuery(vector) + .setFields(recordDefinition.getField(options.getVectorFieldName() != null + ? options.getVectorFieldName() + : firstVectorFieldName).getEffectiveStorageName()) + .setKNearestNeighborsCount(options.getTop())); + + // Configure default vector search options + searchOptions = configureVectorSearchOptions(vectorQueries, options); } - List vectorQueries = new ArrayList<>(); - vectorQueries.add(new VectorizedQuery(vector) - .setFields(recordDefinition.getField(options.getVectorFieldName() != null - ? options.getVectorFieldName() - : firstVectorFieldName).getEffectiveStorageName()) - .setKNearestNeighborsCount(options.getTop())); + // Configure additional search options + if (additionalSearchOptions != null) { + searchOptions + .setQueryType(additionalSearchOptions.getQueryType()) + .setSemanticSearchOptions(additionalSearchOptions.getSemanticSearchOptions()) + .setFacets(additionalSearchOptions.getFacets() != null + ? additionalSearchOptions.getFacets().toArray(new String[0]) + : null) + .setHighlightFields(additionalSearchOptions.getHighlightFields() != null + ? additionalSearchOptions.getHighlightFields().toArray(new String[0]) + : null) + .setHighlightPreTag(additionalSearchOptions.getHighlightPreTag()) + .setHighlightPostTag(additionalSearchOptions.getHighlightPostTag()) + .setMinimumCoverage(additionalSearchOptions.getMinimumCoverage()) + .setOrderBy(additionalSearchOptions.getOrderBy() != null + ? additionalSearchOptions.getOrderBy().toArray(new String[0]) + : null) + .setScoringParameters(additionalSearchOptions.getScoringParameters() != null + ? additionalSearchOptions.getScoringParameters().stream() + .map(s -> new ScoringParameter(s.getName(), s.getValues())) + .toArray(ScoringParameter[]::new) + : null) + .setScoringProfile(additionalSearchOptions.getScoringProfile()) + .setSearchFields(additionalSearchOptions.getSearchFields() != null + ? additionalSearchOptions.getSearchFields().toArray(new String[0]) + : null) + .setIncludeTotalCount(additionalSearchOptions.isTotalCountIncluded()) + .setSearchMode(additionalSearchOptions.getSearchMode()) + .setScoringStatistics(additionalSearchOptions.getScoringStatistics()) + .setSessionId(additionalSearchOptions.getSessionId()); + + // Override default vector options if provided + if (additionalSearchOptions.getFilter() != null) { + searchOptions.setFilter(additionalSearchOptions.getFilter()); + } + if (additionalSearchOptions.getTop() != null) { + searchOptions.setTop(additionalSearchOptions.getTop()); + } + if (additionalSearchOptions.getSkip() != null) { + searchOptions.setSkip(additionalSearchOptions.getSkip()); + } + if (additionalSearchOptions.getVectorSearchOptions() != null) { + searchOptions + .setVectorSearchOptions(additionalSearchOptions.getVectorSearchOptions()); + } + if (additionalSearchOptions.getSelect() != null) { + searchOptions.setSelect(additionalSearchOptions.getSelect().toArray(new String[0])); + } + } - return searchAndMapAsync(vectorQueries, options, - new GetRecordOptions(options.isIncludeVectors())); + return searchAndMapAsync(searchText, searchOptions, + options != null && options.isIncludeVectors()); } }