Skip to content

Commit 337350f

Browse files
authored
Fix Uninstall & UninstallAsync support in DataMigrationManager (#18608)
1 parent fd8d012 commit 337350f

File tree

2 files changed

+263
-13
lines changed

2 files changed

+263
-13
lines changed

src/OrchardCore/OrchardCore.Data.YesSql/Migration/DataMigrationManager.cs

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public async Task<IEnumerable<string>> GetFeaturesThatNeedUpdateAsync()
8282
return CreateUpgradeLookupTable(dataMigration).ContainsKey(record.Version.Value);
8383
}
8484

85-
return GetMethod(dataMigration, "Create") != null;
85+
return GetCreateMethod(dataMigration) != null;
8686
});
8787

8888
return outOfDateMigrations.Select(m => _typeFeatureProvider.GetFeatureForDependency(m.GetType()).Id).ToArray();
@@ -103,11 +103,22 @@ public async Task Uninstall(string feature)
103103
// get current version for this migration
104104
var dataMigrationRecord = await GetDataMigrationRecordAsync(tempMigration);
105105

106-
var uninstallMethod = GetMethod(migration, "Uninstall");
106+
var uninstallMethod = GetUninstallMethod(migration);
107107

108108
if (uninstallMethod != null)
109109
{
110-
await InvokeMethodAsync(uninstallMethod, migration);
110+
if (uninstallMethod.ReturnType == typeof(Task))
111+
{
112+
await (Task)uninstallMethod.Invoke(migration, []);
113+
}
114+
else if (uninstallMethod.ReturnType == typeof(void))
115+
{
116+
uninstallMethod.Invoke(migration, []);
117+
}
118+
else
119+
{
120+
throw new InvalidOperationException("Invalid return type used in a migration method.");
121+
}
111122
}
112123

113124
if (dataMigrationRecord == null)
@@ -201,15 +212,15 @@ private async Task UpdateAsync(string featureId)
201212
if (current == 0)
202213
{
203214
// Try to get a Create method.
204-
var createMethod = GetMethod(migration, "Create");
215+
var createMethod = GetCreateMethod(migration);
205216

206217
if (createMethod == null)
207218
{
208219
_logger.LogWarning("The migration '{Name}' for '{FeatureName}' does not contain a proper Create or CreateAsync method.", migration.GetType().FullName, featureId);
209220
continue;
210221
}
211222

212-
current = await InvokeMethodAsync(createMethod, migration);
223+
current = await InvokeCreateOrUpdateMethodAsync(createMethod, migration);
213224
}
214225

215226
var lookupTable = CreateUpgradeLookupTable(migration);
@@ -218,7 +229,7 @@ private async Task UpdateAsync(string featureId)
218229
{
219230
_logger.LogInformation("Applying migration for '{Migration}' in '{FeatureId}' from version {Version}.", migration.GetType().FullName, featureId, current);
220231

221-
current = await InvokeMethodAsync(methodInfo, migration);
232+
current = await InvokeCreateOrUpdateMethodAsync(methodInfo, migration);
222233
}
223234

224235
// If current is 0, it means no upgrade/create method was found or succeeded.
@@ -243,7 +254,7 @@ private async Task UpdateAsync(string featureId)
243254
}
244255
}
245256

246-
private static async Task<int> InvokeMethodAsync(MethodInfo method, IDataMigration migration)
257+
private static async Task<int> InvokeCreateOrUpdateMethodAsync(MethodInfo method, IDataMigration migration)
247258
{
248259
if (method.ReturnType == typeof(Task<int>))
249260
{
@@ -307,21 +318,19 @@ private static Tuple<int, MethodInfo> GetUpdateFromMethod(MethodInfo methodInfo)
307318
return null;
308319
}
309320

310-
/// <summary>
311-
/// Returns the method from a data migration class that matches the given name if found.
312-
/// </summary>
313-
private static MethodInfo GetMethod(IDataMigration dataMigration, string name)
321+
private static MethodInfo GetCreateMethod(IDataMigration dataMigration)
314322
{
323+
var methodName = "Create";
315324
// First try to find a method that match the given name. (Ex. Create())
316-
var methodInfo = dataMigration.GetType().GetMethod(name, BindingFlags.Public | BindingFlags.Instance);
325+
var methodInfo = dataMigration.GetType().GetMethod(methodName, BindingFlags.Public | BindingFlags.Instance);
317326

318327
if (methodInfo != null && (methodInfo.ReturnType == typeof(int) || methodInfo.ReturnType == typeof(Task<int>)))
319328
{
320329
return methodInfo;
321330
}
322331

323332
// At this point, try to find a method that matches the given name and ends with Async. (Ex. CreateAsync())
324-
methodInfo = dataMigration.GetType().GetMethod(name + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);
333+
methodInfo = dataMigration.GetType().GetMethod(methodName + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);
325334

326335
if (methodInfo != null && methodInfo.ReturnType == typeof(Task<int>))
327336
{
@@ -330,4 +339,24 @@ private static MethodInfo GetMethod(IDataMigration dataMigration, string name)
330339

331340
return null;
332341
}
342+
343+
private static MethodInfo GetUninstallMethod(IDataMigration dataMigration)
344+
{
345+
var methodName = "Uninstall";
346+
var methodInfo = dataMigration.GetType().GetMethod(methodName, BindingFlags.Public | BindingFlags.Instance);
347+
348+
if (methodInfo != null && (methodInfo.ReturnType == typeof(void) || methodInfo.ReturnType == typeof(Task)))
349+
{
350+
return methodInfo;
351+
}
352+
353+
methodInfo = dataMigration.GetType().GetMethod(methodName + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);
354+
355+
if (methodInfo != null && methodInfo.ReturnType == typeof(Task))
356+
{
357+
return methodInfo;
358+
}
359+
360+
return null;
361+
}
333362
}
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
using System.Data.Common;
2+
using OrchardCore.Environment.Extensions;
3+
using OrchardCore.Environment.Extensions.Features;
4+
using ISession = YesSql.ISession;
5+
6+
namespace OrchardCore.Data.Migration.Tests;
7+
8+
public class DataMigrationManagerTests
9+
{
10+
[Fact]
11+
public async Task UpdateAsync_ShouldExecuteDataMigration_CreateMethod_OnFreshMigration()
12+
{
13+
// Arrange
14+
var migration1 = new Migration1();
15+
var migration2 = new Migration2();
16+
var migrationManager = GetDataMigrationManager([migration1, migration2]);
17+
18+
// Act
19+
await migrationManager.UpdateAsync("TestFeature");
20+
21+
// Assert
22+
Assert.True(migration1.CreateCalled);
23+
Assert.True(migration2.CreateCalled);
24+
}
25+
26+
[Fact]
27+
public async Task UpdateAsync_ShouldExecuteDataMigration_UpdateFromMethods()
28+
{
29+
// Arrange
30+
var migration1 = new Migration1();
31+
var migration2 = new Migration2();
32+
var migrationManager = GetDataMigrationManager([migration1, migration2]);
33+
34+
// Act
35+
await migrationManager.UpdateAsync("TestFeature");
36+
37+
// Assert
38+
Assert.Equal(2, migration1.UpdateFromCalls);
39+
Assert.Equal(0, migration2.UpdateFromCalls);
40+
}
41+
42+
[Fact]
43+
public async Task Uninstall_ShouldExecuteDataMigration_UninstallMethod()
44+
{
45+
// Arrange
46+
var migration1 = new Migration1();
47+
var migration2 = new Migration2();
48+
var migrationManager = GetDataMigrationManager([migration1, migration2]);
49+
50+
// Act
51+
await migrationManager.Uninstall("TestFeature");
52+
53+
// Assert
54+
Assert.True(migration1.UninstallCalled);
55+
Assert.True(migration2.UninstallCalled);
56+
}
57+
58+
private static DataMigrationManager GetDataMigrationManager(IEnumerable<DataMigration> dataMigrations)
59+
{
60+
var featureInfo = new Mock<IFeatureInfo>();
61+
featureInfo.Setup(f => f.Id).Returns("TestFeature");
62+
63+
var typeFeatureProviderMock = new Mock<ITypeFeatureProvider>();
64+
typeFeatureProviderMock.Setup(m => m.GetFeatureForDependency(It.IsAny<Type>()))
65+
.Returns(featureInfo.Object);
66+
67+
var extensionManagerMock = new Mock<IExtensionManager>();
68+
extensionManagerMock.Setup(m => m.GetFeatureDependencies(It.IsAny<string>()))
69+
.Returns(Enumerable.Empty<IFeatureInfo>());
70+
71+
var sessionMock = new Mock<ISession>();
72+
sessionMock.Setup(s => s.BeginTransactionAsync())
73+
.ReturnsAsync(Mock.Of<DbTransaction>());
74+
75+
sessionMock.Setup(s => s.Query())
76+
.Returns(new FakeQuery());
77+
78+
sessionMock.Setup(s => s.SaveAsync(It.IsAny<object>()))
79+
.Returns(Task.CompletedTask);
80+
81+
var storeMock = new Mock<IStore>();
82+
storeMock.Setup(s => s.Configuration).Returns(new Configuration());
83+
84+
return new DataMigrationManager(
85+
typeFeatureProviderMock.Object,
86+
dataMigrations,
87+
sessionMock.Object,
88+
storeMock.Object,
89+
extensionManagerMock.Object,
90+
NullLogger<DataMigrationManager>.Instance);
91+
}
92+
93+
private sealed class Migration1 : DataMigration
94+
{
95+
public bool CreateCalled { get; private set; }
96+
97+
public bool UninstallCalled { get; private set; }
98+
99+
public int UpdateFromCalls { get; private set; }
100+
101+
public int Create()
102+
{
103+
CreateCalled = true;
104+
105+
return 1;
106+
}
107+
108+
public int UpdateFrom1()
109+
{
110+
++UpdateFromCalls;
111+
112+
return 2;
113+
}
114+
115+
public Task<int> UpdateFrom2Async()
116+
{
117+
++UpdateFromCalls;
118+
119+
return Task.FromResult(3);
120+
}
121+
122+
#pragma warning disable CA1822 // Mark members as static
123+
public int UpdateFromInvalid() => 0;
124+
#pragma warning restore CA1822 // Mark members as static
125+
126+
public void Uninstall() => UninstallCalled = true;
127+
}
128+
129+
private sealed class Migration2 : DataMigration
130+
{
131+
public bool CreateCalled { get; private set; }
132+
133+
public bool UninstallCalled { get; private set; }
134+
135+
public int UpdateFromCalls { get; private set; }
136+
137+
public Task<int> CreateAsync()
138+
{
139+
CreateCalled = true;
140+
141+
return Task.FromResult(1);
142+
}
143+
144+
public Task UninstallAsync()
145+
{
146+
UninstallCalled = true;
147+
148+
return Task.CompletedTask;
149+
}
150+
}
151+
152+
private sealed class FakeQuery : IQuery
153+
{
154+
public IQuery<object> Any()
155+
=> throw new NotImplementedException();
156+
157+
public IQuery<T> For<T>(bool filterType = true) where T : class => new FakeQuery<T>();
158+
159+
IQueryIndex<T> IQuery.ForIndex<T>()
160+
=> throw new NotImplementedException();
161+
}
162+
163+
private sealed class FakeQuery<T> : IQuery<T> where T : class
164+
{
165+
public IQuery<T> All(params Func<IQuery<T>, IQuery<T>>[] predicates)
166+
=> throw new NotImplementedException();
167+
168+
public ValueTask<IQuery<T>> AllAsync(params Func<IQuery<T>, ValueTask<IQuery<T>>>[] predicates)
169+
=> throw new NotImplementedException();
170+
171+
public IQuery<T> Any(params Func<IQuery<T>, IQuery<T>>[] predicates)
172+
=> throw new NotImplementedException();
173+
174+
public ValueTask<IQuery<T>> AnyAsync(params Func<IQuery<T>, ValueTask<IQuery<T>>>[] predicates)
175+
=> throw new NotImplementedException();
176+
177+
public Task<int> CountAsync(CancellationToken cancellationToken = default)
178+
=> throw new NotImplementedException();
179+
180+
public Task<int> CountAsync()
181+
=> throw new NotImplementedException();
182+
183+
public Task<T> FirstOrDefaultAsync(CancellationToken cancellationToken = default)
184+
=> throw new NotImplementedException();
185+
186+
public Task<T> FirstOrDefaultAsync() => Task.FromResult((T)null);
187+
188+
public string GetTypeAlias(Type t)
189+
=> throw new NotImplementedException();
190+
191+
public Task<IEnumerable<T>> ListAsync(CancellationToken cancellationToken = default)
192+
=> throw new NotImplementedException();
193+
194+
public Task<IEnumerable<T>> ListAsync()
195+
=> throw new NotImplementedException();
196+
197+
public IQuery<T> NoDuplicates()
198+
=> throw new NotImplementedException();
199+
200+
public IQuery<T> Skip(int count)
201+
=> throw new NotImplementedException();
202+
203+
public IQuery<T> Take(int count)
204+
=> throw new NotImplementedException();
205+
206+
public IAsyncEnumerable<T> ToAsyncEnumerable(CancellationToken cancellationToken = default)
207+
=> throw new NotImplementedException();
208+
209+
public IAsyncEnumerable<T> ToAsyncEnumerable()
210+
=> throw new NotImplementedException();
211+
212+
public IQuery<T> With(Type indexType)
213+
=> throw new NotImplementedException();
214+
215+
IQuery<T, TIndex> IQuery<T>.With<TIndex>()
216+
=> throw new NotImplementedException();
217+
218+
IQuery<T, TIndex> IQuery<T>.With<TIndex>(System.Linq.Expressions.Expression<Func<TIndex, bool>> predicate)
219+
=> throw new NotImplementedException();
220+
}
221+
}

0 commit comments

Comments
 (0)