diff --git a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java index 4bbf4517b..4626f522b 100644 --- a/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java +++ b/testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/Scorer.java @@ -1,31 +1,21 @@ package io.quarkiverse.langchain4j.testing.scorer; -import java.io.Closeable; import java.util.Comparator; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.function.Function; import org.jboss.logging.Logger; -public class Scorer implements Closeable { +public class Scorer { private static final Logger LOG = Logger.getLogger(Scorer.class); private final ExecutorService executor; - public Scorer(int concurrency) { - if (concurrency > 1) { - executor = Executors.newFixedThreadPool(concurrency); - } else { - executor = Executors.newSingleThreadExecutor(); - } - } - - public Scorer() { - this(1); + public Scorer(ExecutorService executor) { + this.executor = executor; } @SuppressWarnings({ "unchecked" }) @@ -74,10 +64,6 @@ public EvaluationReport evaluate( return new EvaluationReport<>(orderedEvalutions); } - public void close() { - executor.shutdown(); - } - public record EvaluationResult( EvaluationSample sample, T result, Throwable thrown, boolean passed) { public static EvaluationResult fromCompletedEvaluation( diff --git a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java index 279e20887..446a62045 100644 --- a/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java +++ b/testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java @@ -3,6 +3,8 @@ import static org.assertj.core.api.Assertions.assertThat; import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.function.Function; import java.util.stream.Stream; @@ -12,18 +14,20 @@ class ScorerTest { private Scorer scorer; + private ExecutorService executor; @AfterEach void tearDown() { - if (scorer != null) { - scorer.close(); + if (executor != null) { + executor.shutdown(); } } @SuppressWarnings("unchecked") @Test void evaluateShouldReturnCorrectReport() { - scorer = new Scorer(2); + executor = Executors.newFixedThreadPool(2); + scorer = new Scorer(executor); EvaluationSample sample1 = new EvaluationSample<>( "Sample1", @@ -62,7 +66,8 @@ void evaluateShouldReturnCorrectReport() { @SuppressWarnings("unchecked") @Test void evaluateShouldReturnCorrectlyOrderedReport() { - scorer = new Scorer(2); + executor = Executors.newFixedThreadPool(2); + scorer = new Scorer(executor); var sleeps = Stream.of(25l, 0l); var samples = new Samples<>( sleeps @@ -93,7 +98,8 @@ private String sleep(Parameters params) { @Test @SuppressWarnings("unchecked") void evaluateShouldHandleExceptionsInFunction() { - scorer = new Scorer(); + executor = Executors.newSingleThreadExecutor(); + scorer = new Scorer(executor); EvaluationSample sample = new EvaluationSample<>( "Sample1", new Parameters().add(new Parameter.UnnamedParameter("param1")), @@ -118,7 +124,8 @@ void evaluateShouldHandleExceptionsInFunction() { @Test @SuppressWarnings("unchecked") void evaluateShouldHandleMultipleStrategies() { - scorer = new Scorer(); + executor = Executors.newSingleThreadExecutor(); + scorer = new Scorer(executor); EvaluationSample sample = new EvaluationSample<>( "Sample1", diff --git a/testing/scorer/scorer-junit5/pom.xml b/testing/scorer/scorer-junit5/pom.xml index a14109da0..115ace71c 100644 --- a/testing/scorer/scorer-junit5/pom.xml +++ b/testing/scorer/scorer-junit5/pom.xml @@ -32,7 +32,15 @@ io.quarkiverse.langchain4j quarkus-langchain4j-testing-scorer-core - 999-SNAPSHOT + ${project.version} + + + io.quarkus + quarkus-junit5 + + + io.quarkus + quarkus-smallrye-context-propagation @@ -48,4 +56,4 @@ - \ No newline at end of file + diff --git a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerConfiguration.java b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerConfiguration.java deleted file mode 100644 index 85ba53027..000000000 --- a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerConfiguration.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.quarkiverse.langchain4j.scorer.junit5; - -import static java.lang.annotation.ElementType.*; -import static java.lang.annotation.RetentionPolicy.*; - -import java.lang.annotation.Retention; -import java.lang.annotation.Target; - -/** - * Allows configuring the number of threads to use for the evaluation. - * The target of this annotation should be a parameter or a field of type - * {@link io.quarkiverse.langchain4j.testing.scorer.Scorer}. - */ -@Retention(RUNTIME) -@Target({ FIELD, PARAMETER }) -public @interface ScorerConfiguration { - - /** - * @return the number of threads to use for the evaluation. - */ - int concurrency() default 1; - -} diff --git a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java index 6173f3dee..7457d721a 100644 --- a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java +++ b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerExtension.java @@ -1,85 +1,28 @@ package io.quarkiverse.langchain4j.scorer.junit5; -import java.lang.reflect.Field; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CopyOnWriteArrayList; - -import org.junit.jupiter.api.extension.AfterEachCallback; -import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ParameterContext; import org.junit.jupiter.api.extension.ParameterResolutionException; import org.junit.jupiter.api.extension.ParameterResolver; -import org.junit.platform.commons.support.HierarchyTraversalMode; -import org.junit.platform.commons.support.ReflectionSupport; import io.quarkiverse.langchain4j.testing.scorer.Samples; -import io.quarkiverse.langchain4j.testing.scorer.Scorer; import io.quarkiverse.langchain4j.testing.scorer.YamlLoader; -public class ScorerExtension implements BeforeEachCallback, AfterEachCallback, ParameterResolver { - private final List scorers = new CopyOnWriteArrayList<>(); - - @Override - public void beforeEach(ExtensionContext extensionContext) { - Optional> maybeClass = extensionContext.getTestClass(); - if (maybeClass.isPresent()) { - List fields = ReflectionSupport.findFields(maybeClass.get(), - field -> field.getType().isAssignableFrom(Scorer.class), HierarchyTraversalMode.TOP_DOWN); - for (Field field : fields) { - Scorer sc; - if (field.isAnnotationPresent(ScorerConfiguration.class)) { - ScorerConfiguration annotation = field.getAnnotation(ScorerConfiguration.class); - sc = new Scorer(annotation.concurrency()); - } else { - sc = new Scorer(); - } - scorers.add(sc); - inject(sc, extensionContext.getRequiredTestInstance(), field); - } - } - } - - private void inject(Scorer sc, Object instance, Field field) { - try { - field.setAccessible(true); - field.set(instance, sc); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); - } - } - - @Override - public void afterEach(ExtensionContext extensionContext) { - for (Scorer scorer : scorers) { - scorer.close(); - } - } +public class ScorerExtension implements ParameterResolver { @Override public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { return (parameterContext.findAnnotation(SampleLocation.class).isPresent() - && parameterContext.getParameter().getType().isAssignableFrom(Samples.class)) - || parameterContext.getParameter().getType().isAssignableFrom(Scorer.class); + && parameterContext.getParameter().getType().isAssignableFrom(Samples.class)); } @Override public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { - if (parameterContext.getParameter().getType().isAssignableFrom(Scorer.class)) { - if (parameterContext.getParameter().isAnnotationPresent(ScorerConfiguration.class)) { - ScorerConfiguration annotation = parameterContext.getParameter().getAnnotation(ScorerConfiguration.class); - return new Scorer(annotation.concurrency()); - } else { - return new Scorer(); - } - } else { - // List of data samples - String path = parameterContext.findAnnotation(SampleLocation.class).orElseThrow().value(); - return YamlLoader.load(path); - } + // List of data samples + String path = parameterContext.findAnnotation(SampleLocation.class).orElseThrow().value(); + return YamlLoader.load(path); } } diff --git a/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerProducer.java b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerProducer.java new file mode 100644 index 000000000..1fc62f62d --- /dev/null +++ b/testing/scorer/scorer-junit5/src/main/java/io/quarkiverse/langchain4j/scorer/junit5/ScorerProducer.java @@ -0,0 +1,15 @@ +package io.quarkiverse.langchain4j.scorer.junit5; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.eclipse.microprofile.context.ManagedExecutor; + +import io.quarkiverse.langchain4j.testing.scorer.Scorer; + +public class ScorerProducer { + + @ApplicationScoped + public Scorer scorer(ManagedExecutor executor) { + return new Scorer(executor); + } +} diff --git a/testing/scorer/scorer-junit5/src/main/resources/META-INF/beans.xml b/testing/scorer/scorer-junit5/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java b/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java index da9ab6d36..535c5a08e 100644 --- a/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java +++ b/testing/scorer/scorer-junit5/src/test/java/io/quarkiverse/langchain4j/scorer/junit5/test/ScorerExtensionTest.java @@ -4,47 +4,18 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mockito; import io.quarkiverse.langchain4j.scorer.junit5.SampleLocation; -import io.quarkiverse.langchain4j.scorer.junit5.ScorerConfiguration; import io.quarkiverse.langchain4j.scorer.junit5.ScorerExtension; import io.quarkiverse.langchain4j.testing.scorer.Samples; -import io.quarkiverse.langchain4j.testing.scorer.Scorer; @ExtendWith(ScorerExtension.class) class ScorerExtensionTest { - @ScorerConfiguration(concurrency = 3) - private Scorer scorerWithConcurrency; - - private Scorer defaultScorer; - - @Test - void scorerFieldInjectionShouldWork() { - assertThat(scorerWithConcurrency).isNotNull(); - assertThat(scorerWithConcurrency).extracting("executor").isNotNull(); - assertThat(defaultScorer).isNotNull(); - assertThat(defaultScorer).extracting("executor").isNotNull(); - } - - @Test - void scorerParameterShouldBeResolved(@ScorerConfiguration(concurrency = 2) Scorer scorer) { - assertThat(scorer).isNotNull(); - assertThat(scorer).extracting("executor").isNotNull(); - } - @Test void samplesParameterShouldBeResolved(@SampleLocation("src/test/resources/test-samples.yaml") Samples samples) { assertThat(samples).isNotNull(); assertThat(samples).hasSizeGreaterThan(0); assertThat(samples.get(0).name()).isEqualTo("Sample1"); // Assuming the YAML has this entry. } - - @Test - void scorerShouldBeClosedAfterTest() { - Scorer mockScorer = Mockito.mock(Scorer.class); - mockScorer.close(); - Mockito.verify(mockScorer).close(); - } }