Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexDocumentsResult;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.ScoringParameter;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.VectorQuery;
import com.azure.search.documents.models.VectorizableTextQuery;
Expand Down Expand Up @@ -39,6 +40,7 @@
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -288,37 +290,40 @@ public Mono<Void> deleteBatchAsync(List<String> keys, DeleteRecordOptions option
}).collect(Collectors.toList())).then();
}

private Mono<VectorSearchResults<Record>> searchAndMapAsync(
List<VectorQuery> vectorQueries, VectorSearchOptions options,
GetRecordOptions getRecordOptions) {

private SearchOptions configureVectorSearchOptions(
List<VectorQuery> vectorQueries, VectorSearchOptions options) {
String filter = AzureAISearchVectorStoreCollectionSearchMapping.getInstance()
.getFilter(options.getVectorSearchFilter(), recordDefinition);

SearchOptions searchOptions = new SearchOptions()
.setFilter(filter)
.setTop(options.getTop())
.setSkip(options.getSkip())
.setScoringParameters()
.setVectorSearchOptions(new com.azure.search.documents.models.VectorSearchOptions()
.setQueries(vectorQueries));

if (!options.isIncludeVectors()) {
searchOptions.setSelect(nonVectorFields.toArray(new String[0]));
}

return searchOptions;
}

private Mono<VectorSearchResults<Record>> searchAndMapAsync(String query,
SearchOptions searchOptions,
boolean includeVectors) {
VectorStoreRecordMapper<Record, SearchDocument> mapper = this.options
.getVectorStoreRecordMapper();

return this.searchAsyncClient.search(null, searchOptions)
return this.searchAsyncClient.search(query, searchOptions)
.flatMap(response -> {
Record record;

// Use custom mapper if available
if (mapper != null && mapper.getStorageModelToRecordMapper() != null) {
record = mapper
.mapStorageModelToRecord(response.getDocument(SearchDocument.class),
getRecordOptions);
new GetRecordOptions(includeVectors));
} else {
record = response.getDocument(this.options.getRecordClass());
}
Expand All @@ -329,7 +334,9 @@ record = response.getDocument(this.options.getRecordClass());
}

/**
* Vectorizable text search. This method searches for records that are similar to the given text.
* Vectorizable text search. This method searches for records that are similar to the given text after vectorization.
* <p>
* Vectorizer configuration must be set up in the Azure AI Search index.
*
* @param searchText The text to search with.
* @param options The options to use for the search.
Expand All @@ -353,8 +360,9 @@ public Mono<VectorSearchResults<Record>> searchAsync(String searchText,
: firstVectorFieldName).getEffectiveStorageName())
.setKNearestNeighborsCount(options.getTop()));

return searchAndMapAsync(vectorQueries, options,
new GetRecordOptions(options.isIncludeVectors()));
return searchAndMapAsync(null,
configureVectorSearchOptions(vectorQueries, options),
options.isIncludeVectors());
}

/**
Expand All @@ -367,22 +375,97 @@ public Mono<VectorSearchResults<Record>> searchAsync(String searchText,
@Override
public Mono<VectorSearchResults<Record>> searchAsync(List<Float> vector,
VectorSearchOptions options) {
if (firstVectorFieldName == null) {
throw new SKException("No vector fields defined. Cannot perform vector search");
}
return hybridSearchAsync(null, vector, options, null);
}

if (options == null) {
options = VectorSearchOptions.createDefault(firstVectorFieldName);
/**
* Hybrid search. This method searches for records that are similar to the given text and vector.
*
* @param searchText The text to search with.
* If null, only vector search is performed.
* @param vector The vector to search with.
* If null, only full text search is performed.
* @param options The vector search options used for the search.
* @param additionalSearchOptions AzureAI search additional options.
* If Filter, Top, Skip, Select or VectorSearchOptions are not null, they will be used instead of the default options.
* <p>
* If null, default search options are used.
*/
public Mono<VectorSearchResults<Record>> hybridSearchAsync(String searchText,
List<Float> vector, VectorSearchOptions options, SearchOptions additionalSearchOptions) {
SearchOptions searchOptions = new SearchOptions();

if (vector != null) {
if (firstVectorFieldName == null) {
throw new SKException("No vector fields defined. Cannot perform vector search");
}

if (options == null) {
options = VectorSearchOptions.createDefault(firstVectorFieldName);
}

List<VectorQuery> vectorQueries = new ArrayList<>();
vectorQueries.add(new VectorizedQuery(vector)
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
? options.getVectorFieldName()
: firstVectorFieldName).getEffectiveStorageName())
.setKNearestNeighborsCount(options.getTop()));

// Configure default vector search options
searchOptions = configureVectorSearchOptions(vectorQueries, options);
}

List<VectorQuery> vectorQueries = new ArrayList<>();
vectorQueries.add(new VectorizedQuery(vector)
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
? options.getVectorFieldName()
: firstVectorFieldName).getEffectiveStorageName())
.setKNearestNeighborsCount(options.getTop()));
// Configure additional search options
if (additionalSearchOptions != null) {
searchOptions
.setQueryType(additionalSearchOptions.getQueryType())
.setSemanticSearchOptions(additionalSearchOptions.getSemanticSearchOptions())
.setFacets(additionalSearchOptions.getFacets() != null
? additionalSearchOptions.getFacets().toArray(new String[0])
: null)
.setHighlightFields(additionalSearchOptions.getHighlightFields() != null
? additionalSearchOptions.getHighlightFields().toArray(new String[0])
: null)
.setHighlightPreTag(additionalSearchOptions.getHighlightPreTag())
.setHighlightPostTag(additionalSearchOptions.getHighlightPostTag())
.setMinimumCoverage(additionalSearchOptions.getMinimumCoverage())
.setOrderBy(additionalSearchOptions.getOrderBy() != null
? additionalSearchOptions.getOrderBy().toArray(new String[0])
: null)
.setScoringParameters(additionalSearchOptions.getScoringParameters() != null
? additionalSearchOptions.getScoringParameters().stream()
.map(s -> new ScoringParameter(s.getName(), s.getValues()))
.toArray(ScoringParameter[]::new)
: null)
.setScoringProfile(additionalSearchOptions.getScoringProfile())
.setSearchFields(additionalSearchOptions.getSearchFields() != null
? additionalSearchOptions.getSearchFields().toArray(new String[0])
: null)
.setIncludeTotalCount(additionalSearchOptions.isTotalCountIncluded())
.setSearchMode(additionalSearchOptions.getSearchMode())
.setScoringStatistics(additionalSearchOptions.getScoringStatistics())
.setSessionId(additionalSearchOptions.getSessionId());

// Override default vector options if provided
if (additionalSearchOptions.getFilter() != null) {
searchOptions.setFilter(additionalSearchOptions.getFilter());
}
if (additionalSearchOptions.getTop() != null) {
searchOptions.setTop(additionalSearchOptions.getTop());
}
if (additionalSearchOptions.getSkip() != null) {
searchOptions.setSkip(additionalSearchOptions.getSkip());
}
if (additionalSearchOptions.getVectorSearchOptions() != null) {
searchOptions
.setVectorSearchOptions(additionalSearchOptions.getVectorSearchOptions());
}
if (additionalSearchOptions.getSelect() != null) {
searchOptions.setSelect(additionalSearchOptions.getSelect().toArray(new String[0]));
}
}

return searchAndMapAsync(vectorQueries, options,
new GetRecordOptions(options.isIncludeVectors()));
return searchAndMapAsync(searchText, searchOptions,
options != null && options.isIncludeVectors());
}
}