Skip to content

Commit b91ace5

Browse files
authored
Fix type inference for ThenInclude with casting. (#466)
1 parent 88a6898 commit b91ace5

File tree

5 files changed

+55
-16
lines changed

5 files changed

+55
-16
lines changed

src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,19 @@ private IncludeEvaluator() { }
4141
/// <inheritdoc/>
4242
public IQueryable<T> GetQuery<T>(IQueryable<T> query, ISpecification<T> specification) where T : class
4343
{
44-
Type? previousReturnType = null;
4544
foreach (var includeExpression in specification.IncludeExpressions)
4645
{
4746
var lambdaExpr = includeExpression.LambdaExpression;
4847

4948
if (includeExpression.Type == IncludeTypeEnum.Include)
5049
{
5150
var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, null);
52-
previousReturnType = lambdaExpr.ReturnType;
5351
var include = _cache.GetOrAdd(key, CreateIncludeDelegate);
5452
query = (IQueryable<T>)include(query, lambdaExpr);
5553
}
5654
else if (includeExpression.Type == IncludeTypeEnum.ThenInclude)
5755
{
58-
var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, previousReturnType);
59-
previousReturnType = lambdaExpr.ReturnType;
56+
var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, includeExpression.PreviousPropertyType);
6057
var include = _cache.GetOrAdd(key, CreateThenIncludeDelegate);
6158
query = (IQueryable<T>)include(query, lambdaExpr);
6259
}
@@ -104,7 +101,7 @@ private static Func<IQueryable, LambdaExpression, IQueryable> CreateThenIncludeD
104101

105102
private static bool IsGenericEnumerable(Type type, out Type propertyType)
106103
{
107-
if (type.IsGenericType && typeof(IEnumerable).IsAssignableFrom(type))
104+
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
108105
{
109106
propertyType = type.GenericTypeArguments[0];
110107
return true;

src/Ardalis.Specification/Builders/Builder_Include.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public static IIncludableSpecificationBuilder<T, TResult, TProperty> Include<T,
101101
{
102102
if (condition)
103103
{
104-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.Include);
104+
var expr = new IncludeExpressionInfo(navigationSelector);
105105
builder.Specification.Add(expr);
106106
}
107107

@@ -139,7 +139,7 @@ public static IIncludableSpecificationBuilder<T, TProperty> Include<T, TProperty
139139
{
140140
if (condition)
141141
{
142-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.Include);
142+
var expr = new IncludeExpressionInfo(navigationSelector);
143143
builder.Specification.Add(expr);
144144
}
145145

@@ -183,7 +183,7 @@ public static IIncludableSpecificationBuilder<TEntity, TResult, TProperty> ThenI
183183
{
184184
if (condition && !Specification<TEntity>.IsChainDiscarded)
185185
{
186-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.ThenInclude);
186+
var expr = new IncludeExpressionInfo(navigationSelector, typeof(TPreviousProperty));
187187
builder.Specification.Add(expr);
188188
}
189189
else
@@ -228,7 +228,7 @@ public static IIncludableSpecificationBuilder<TEntity, TProperty> ThenInclude<TE
228228
{
229229
if (condition && !Specification<TEntity>.IsChainDiscarded)
230230
{
231-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.ThenInclude);
231+
var expr = new IncludeExpressionInfo(navigationSelector, typeof(TPreviousProperty));
232232
builder.Specification.Add(expr);
233233
}
234234
else
@@ -275,7 +275,7 @@ public static IIncludableSpecificationBuilder<TEntity, TResult, TProperty> ThenI
275275
{
276276
if (condition && !Specification<TEntity>.IsChainDiscarded)
277277
{
278-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.ThenInclude);
278+
var expr = new IncludeExpressionInfo(navigationSelector, typeof(IEnumerable<TPreviousProperty>));
279279
builder.Specification.Add(expr);
280280
}
281281
else
@@ -320,7 +320,7 @@ public static IIncludableSpecificationBuilder<TEntity, TProperty> ThenInclude<TE
320320
{
321321
if (condition && !Specification<TEntity>.IsChainDiscarded)
322322
{
323-
var expr = new IncludeExpressionInfo(navigationSelector, IncludeTypeEnum.ThenInclude);
323+
var expr = new IncludeExpressionInfo(navigationSelector, typeof(IEnumerable<TPreviousProperty>));
324324
builder.Specification.Add(expr);
325325
}
326326
else

src/Ardalis.Specification/Expressions/IncludeExpressionInfo.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,32 @@ public class IncludeExpressionInfo
1111
/// </summary>
1212
public LambdaExpression LambdaExpression { get; }
1313

14+
/// <summary>
15+
/// The type of the previously included entity.
16+
/// </summary>
17+
public Type? PreviousPropertyType { get; }
18+
1419
/// <summary>
1520
/// The include type.
1621
/// </summary>
1722
public IncludeTypeEnum Type { get; }
1823

19-
public IncludeExpressionInfo(LambdaExpression expression, IncludeTypeEnum includeType)
24+
public IncludeExpressionInfo(LambdaExpression expression)
25+
{
26+
_ = expression ?? throw new ArgumentNullException(nameof(expression));
27+
28+
LambdaExpression = expression;
29+
PreviousPropertyType = null;
30+
Type = IncludeTypeEnum.Include;
31+
}
32+
33+
public IncludeExpressionInfo(LambdaExpression expression, Type previousPropertyType)
2034
{
2135
_ = expression ?? throw new ArgumentNullException(nameof(expression));
36+
_ = previousPropertyType ?? throw new ArgumentNullException(nameof(previousPropertyType));
2237

2338
LambdaExpression = expression;
24-
Type = includeType;
39+
PreviousPropertyType = previousPropertyType;
40+
Type = IncludeTypeEnum.ThenInclude;
2541
}
2642
}

tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void QueriesMatch_GivenInheritanceModel()
3535
var spec = new Specification<Bar>();
3636
spec.Query
3737
.Include(x => x.BarChildren)
38-
.ThenInclude<Bar, BarChild, BarDerivedInfo>(x => (x as BarDerived)!.BarDerivedInfo);
38+
.ThenInclude(x => (x as BarDerived)!.BarDerivedInfo);
3939

4040
var actual = _evaluator
4141
.GetQuery(DbContext.Bars, spec)

tests/Ardalis.Specification.Tests/Expressions/IncludeExpressionInfoTests.cs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,44 @@ public record City(int Id);
99
[Fact]
1010
public void Constructor_ThrowsArgumentNullException_GivenNullForLambdaExpression()
1111
{
12-
var sut = () => new IncludeExpressionInfo(null!, IncludeTypeEnum.Include);
12+
var sut = () => new IncludeExpressionInfo(null!);
1313

1414
sut.Should().Throw<ArgumentNullException>().WithParameterName("expression");
15+
16+
17+
sut = () => new IncludeExpressionInfo(null!, typeof(Customer));
18+
19+
sut.Should().Throw<ArgumentNullException>().WithParameterName("expression");
20+
}
21+
22+
[Fact]
23+
public void Constructor_ThrowsArgumentNullException_GivenNullForPreviousPropertyType()
24+
{
25+
Expression<Func<Customer, Address>> expr = x => x.Address;
26+
var sut = () => new IncludeExpressionInfo(expr, null!);
27+
28+
sut.Should().Throw<ArgumentNullException>().WithParameterName("previousPropertyType");
1529
}
1630

1731
[Fact]
1832
public void Constructor_GivenIncludeExpression()
1933
{
2034
Expression<Func<Customer, Address>> expr = x => x.Address;
21-
var sut = new IncludeExpressionInfo(expr, IncludeTypeEnum.Include);
35+
var sut = new IncludeExpressionInfo(expr);
2236

2337
sut.Type.Should().Be(IncludeTypeEnum.Include);
2438
sut.LambdaExpression.Should().Be(expr);
2539
}
40+
41+
[Fact]
42+
public void Constructor_GivenThenIncludeExpressionAndPreviousPropertyType()
43+
{
44+
Expression<Func<Address, City>> expr = x => x.City;
45+
var previousPropertyType = typeof(Customer);
46+
var sut = new IncludeExpressionInfo(expr, previousPropertyType);
47+
48+
sut.Type.Should().Be(IncludeTypeEnum.ThenInclude);
49+
sut.LambdaExpression.Should().Be(expr);
50+
sut.PreviousPropertyType.Should().Be(previousPropertyType);
51+
}
2652
}

0 commit comments

Comments
 (0)