diff --git a/src/NzbDrone.Core.Test/Datastore/BasicRepositoryFixture.cs b/src/NzbDrone.Core.Test/Datastore/BasicRepositoryFixture.cs index fac7e7f00..eb9f4ef18 100644 --- a/src/NzbDrone.Core.Test/Datastore/BasicRepositoryFixture.cs +++ b/src/NzbDrone.Core.Test/Datastore/BasicRepositoryFixture.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using FizzWare.NBuilder; using FluentAssertions; using NUnit.Framework; @@ -33,277 +34,277 @@ public void Setup() } [Test] - public void should_be_able_to_insert() + public async Task should_be_able_to_insert() { - Subject.Insert(_basicList[0]); - Subject.All().Should().HaveCount(1); + await Subject.InsertAsync(_basicList[0]); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(1); } [Test] - public void should_be_able_to_insert_many() + public async Task should_be_able_to_insert_many() { - Subject.InsertMany(_basicList); - Subject.All().Should().HaveCount(5); + await Subject.InsertManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(5); } [Test] public void insert_many_should_throw_if_id_not_zero() { _basicList[1].Id = 999; - Assert.Throws(() => Subject.InsertMany(_basicList)); + Assert.ThrowsAsync(() => Subject.InsertManyAsync(_basicList)); } [Test] - public void should_be_able_to_get_count() + public async Task should_be_able_to_get_count() { - Subject.InsertMany(_basicList); - Subject.Count().Should().Be(_basicList.Count); + await Subject.InsertManyAsync(_basicList); + (await Subject.CountAsync()).Should().Be(_basicList.Count); } [Test] - public void should_be_able_to_find_by_id() + public async Task should_be_able_to_find_by_id() { - Subject.InsertMany(_basicList); - var storeObject = Subject.Get(_basicList[1].Id); + await Subject.InsertManyAsync(_basicList); + var storeObject = await Subject.GetAsync(_basicList[1].Id); storeObject.Should().BeEquivalentTo(_basicList[1], o => o.IncludingAllRuntimeProperties()); } [Test] - public void should_be_able_to_update() + public async Task should_be_able_to_update() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var item = _basicList[1]; item.Interval = 999; - Subject.Update(item); + await Subject.UpdateAsync(item); - Subject.All().Should().BeEquivalentTo(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList); } [Test] - public void should_be_able_to_upsert_new() + public async Task should_be_able_to_upsert_new() { - Subject.Upsert(_basicList[0]); - Subject.All().Should().HaveCount(1); + await Subject.UpsertAsync(_basicList[0]); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(1); } [Test] - public void should_be_able_to_upsert_existing() + public async Task should_be_able_to_upsert_existing() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var item = _basicList[1]; item.Interval = 999; - Subject.Upsert(item); + await Subject.UpsertAsync(item); - Subject.All().Should().BeEquivalentTo(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList); } [Test] - public void should_be_able_to_update_single_field() + public async Task should_be_able_to_update_single_field() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var item = _basicList[1]; var executionBackup = item.LastExecution; item.Interval = 999; item.LastExecution = DateTime.UtcNow; - Subject.SetFields(item, x => x.Interval); + await Subject.SetFieldsAsync(item, x => x.Interval); item.LastExecution = executionBackup; - Subject.All().Should().BeEquivalentTo(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList); } [Test] - public void set_fields_should_throw_if_id_zero() + public async Task set_fields_should_throw_if_id_zero() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); _basicList[1].Id = 0; _basicList[1].LastExecution = DateTime.UtcNow; - Assert.Throws(() => Subject.SetFields(_basicList[1], x => x.Interval)); + Assert.ThrowsAsync(() => Subject.SetFieldsAsync(_basicList[1], x => x.Interval)); } [Test] - public void should_be_able_to_delete_model_by_id() + public async Task should_be_able_to_delete_model_by_id() { - Subject.InsertMany(_basicList); - Subject.All().Should().HaveCount(5); + await Subject.InsertManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(5); - Subject.Delete(_basicList[0].Id); - Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id)); + await Subject.DeleteAsync(_basicList[0].Id); + (await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id)); } [Test] - public void should_be_able_to_delete_model_by_object() + public async Task should_be_able_to_delete_model_by_object() { - Subject.InsertMany(_basicList); - Subject.All().Should().HaveCount(5); + await Subject.InsertManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(5); - Subject.Delete(_basicList[0]); - Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id)); + await Subject.DeleteAsync(_basicList[0]); + (await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id)); } [Test] - public void get_many_should_return_empty_list_if_no_ids() + public async Task get_many_should_return_empty_list_if_no_ids() { - Subject.Get(new List()).Should().BeEquivalentTo(new List()); + (await Subject.GetAsync(new List()).ToListAsync()).Should().BeEquivalentTo(new List()); } [Test] - public void get_many_should_throw_if_not_all_found() + public async Task get_many_should_throw_if_not_all_found() { - Subject.InsertMany(_basicList); - Assert.Throws(() => Subject.Get(new[] { 999 })); + await Subject.InsertManyAsync(_basicList); + Assert.ThrowsAsync(async () => await Subject.GetAsync([999]).ToListAsync()); } [Test] - public void should_be_able_to_find_by_multiple_id() + public async Task should_be_able_to_find_by_multiple_id() { - Subject.InsertMany(_basicList); - var storeObject = Subject.Get(_basicList.Take(2).Select(x => x.Id)); + await Subject.InsertManyAsync(_basicList); + var storeObject = await Subject.GetAsync(_basicList.Take(2).Select(x => x.Id)).ToListAsync(); storeObject.Select(x => x.Id).Should().BeEquivalentTo(_basicList.Take(2).Select(x => x.Id)); } [Test] - public void should_be_able_to_update_many() + public async Task should_be_able_to_update_many() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); _basicList.ForEach(x => x.Interval = 999); - Subject.UpdateMany(_basicList); - Subject.All().Should().BeEquivalentTo(_basicList); + await Subject.UpdateManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList); } [Test] - public void update_many_should_throw_if_id_zero() + public async Task update_many_should_throw_if_id_zero() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); _basicList[1].Id = 0; - Assert.Throws(() => Subject.UpdateMany(_basicList)); + Assert.ThrowsAsync(() => Subject.UpdateManyAsync(_basicList)); } [Test] - public void should_be_able_to_update_many_single_field() + public async Task should_be_able_to_update_many_single_field() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var executionBackup = _basicList.Select(x => x.LastExecution).ToList(); _basicList.ForEach(x => x.Interval = 999); _basicList.ForEach(x => x.LastExecution = DateTime.UtcNow); - Subject.SetFields(_basicList, x => x.Interval); + await Subject.SetFieldsAsync(_basicList, x => x.Interval); for (var i = 0; i < _basicList.Count; i++) { _basicList[i].LastExecution = executionBackup[i]; } - Subject.All().Should().BeEquivalentTo(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList); } [Test] - public void set_fields_should_throw_if_any_id_zero() + public async Task set_fields_should_throw_if_any_id_zero() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); _basicList.ForEach(x => x.Interval = 999); _basicList[1].Id = 0; - Assert.Throws(() => Subject.SetFields(_basicList, x => x.Interval)); + Assert.ThrowsAsync(() => Subject.SetFieldsAsync(_basicList, x => x.Interval)); } [Test] - public void should_be_able_to_delete_many_by_model() + public async Task should_be_able_to_delete_many_by_model() { - Subject.InsertMany(_basicList); - Subject.All().Should().HaveCount(5); + await Subject.InsertManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(5); - Subject.DeleteMany(_basicList.Take(2).ToList()); - Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id)); + await Subject.DeleteManyAsync(_basicList.Take(2).ToList()); + (await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id)); } [Test] - public void should_be_able_to_delete_many_by_id() + public async Task should_be_able_to_delete_many_by_id() { - Subject.InsertMany(_basicList); - Subject.All().Should().HaveCount(5); + await Subject.InsertManyAsync(_basicList); + (await Subject.AllAsync().ToListAsync()).Should().HaveCount(5); - Subject.DeleteMany(_basicList.Take(2).Select(x => x.Id).ToList()); - Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id)); + await Subject.DeleteManyAsync(_basicList.Take(2).Select(x => x.Id).ToList()); + (await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id)); } [Test] - public void purge_should_delete_all() + public async Task purge_should_delete_all() { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); AllStoredModels.Should().HaveCount(5); - Subject.Purge(); + await Subject.PurgeAsync(); AllStoredModels.Should().BeEmpty(); } [Test] - public void has_items_should_return_false_with_no_items() + public async Task has_items_should_return_false_with_no_items() { - Subject.HasItems().Should().BeFalse(); + (await Subject.HasItemsAsync()).Should().BeFalse(); } [Test] - public void has_items_should_return_true_with_items() + public async Task has_items_should_return_true_with_items() { - Subject.InsertMany(_basicList); - Subject.HasItems().Should().BeTrue(); + await Subject.InsertManyAsync(_basicList); + (await Subject.HasItemsAsync()).Should().BeTrue(); } [Test] public void single_should_throw_on_empty() { - Assert.Throws(() => Subject.Single()); + Assert.ThrowsAsync(() => Subject.SingleAsync()); } [Test] - public void should_be_able_to_get_single() + public async Task should_be_able_to_get_single() { - Subject.Insert(_basicList[0]); - Subject.Single().Should().BeEquivalentTo(_basicList[0]); + await Subject.InsertAsync(_basicList[0]); + (await Subject.SingleAsync()).Should().BeEquivalentTo(_basicList[0]); } [Test] - public void single_or_default_on_empty_table_should_return_null() + public async Task single_or_default_on_empty_table_should_return_null() { - Subject.SingleOrDefault().Should().BeNull(); + (await Subject.SingleOrDefaultAsync()).Should().BeNull(); } [Test] public void getting_model_with_invalid_id_should_throw() { - Assert.Throws(() => Subject.Get(12)); + Assert.ThrowsAsync(() => Subject.GetAsync(12)); } [Test] - public void get_all_with_empty_db_should_return_empty_list() + public async Task get_all_with_empty_db_should_return_empty_list() { - Subject.All().Should().BeEmpty(); + (await Subject.AllAsync().ToListAsync()).Should().BeEmpty(); } [Test] - public void should_be_able_to_call_ToList_on_empty_queryable() + public async Task should_be_able_to_call_ToList_on_empty_queryable() { - Subject.All().ToList().Should().BeEmpty(); + (await Subject.AllAsync().ToListAsync()).Should().BeEmpty(); } [TestCase(1, 2)] [TestCase(2, 2)] [TestCase(3, 1)] - public void get_paged_should_work(int page, int count) + public async Task get_paged_should_work(int page, int count) { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var data = Subject.GetPaged(new PagingSpec() { Page = page, PageSize = 2, SortKey = "LastExecution", SortDirection = SortDirection.Descending }); data.Page.Should().Be(page); @@ -315,9 +316,9 @@ public void get_paged_should_work(int page, int count) [TestCase(1, 2)] [TestCase(2, 2)] [TestCase(3, 1)] - public void get_paged_should_work_with_null_sort_key(int page, int count) + public async Task get_paged_should_work_with_null_sort_key(int page, int count) { - Subject.InsertMany(_basicList); + await Subject.InsertManyAsync(_basicList); var data = Subject.GetPaged(new PagingSpec() { Page = page, PageSize = 2, SortDirection = SortDirection.Descending }); data.Page.Should().Be(page); diff --git a/src/NzbDrone.Core/Datastore/BasicRepository.cs b/src/NzbDrone.Core/Datastore/BasicRepository.cs index cae2e8302..e28bf8f47 100644 --- a/src/NzbDrone.Core/Datastore/BasicRepository.cs +++ b/src/NzbDrone.Core/Datastore/BasicRepository.cs @@ -5,11 +5,15 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; using Dapper; using NLog; using NzbDrone.Common.Instrumentation; using NzbDrone.Core.Datastore.Events; +using NzbDrone.Core.Datastore.Extensions; using NzbDrone.Core.Messaging.Events; using Polly; using Polly.Retry; @@ -20,25 +24,45 @@ public interface IBasicRepository where TModel : ModelBase, new() { IEnumerable All(); + IAsyncEnumerable AllAsync(CancellationToken cancellationToken = default); int Count(); + Task CountAsync(CancellationToken cancellationToken = default); TModel Find(int id); + Task FindAsync(int id, CancellationToken cancellationToken = default); TModel Get(int id); + Task GetAsync(int id, CancellationToken cancellationToken = default); TModel Insert(TModel model); + Task InsertAsync(TModel model, CancellationToken cancellationToken = default); TModel Update(TModel model); + Task UpdateAsync(TModel model, CancellationToken cancellationToken = default); TModel Upsert(TModel model); + Task UpsertAsync(TModel model, CancellationToken cancellationToken = default); void SetFields(TModel model, params Expression>[] properties); + Task SetFieldsAsync(TModel model, params Expression>[] properties); void Delete(TModel model); + Task DeleteAsync(TModel model, CancellationToken cancellationToken = default); void Delete(int id); + Task DeleteAsync(int id, CancellationToken cancellationToken = default); IEnumerable Get(IEnumerable ids); + IAsyncEnumerable GetAsync(IEnumerable ids, CancellationToken cancellationToken = default); void InsertMany(IList model); + Task InsertManyAsync(IList models, CancellationToken cancellationToken = default); void UpdateMany(IList model); + Task UpdateManyAsync(IList models, CancellationToken cancellationToken = default); void SetFields(IList models, params Expression>[] properties); + Task SetFieldsAsync(IList models, params Expression>[] properties); void DeleteMany(List model); + Task DeleteManyAsync(List models, CancellationToken cancellationToken = default); void DeleteMany(IEnumerable ids); + Task DeleteManyAsync(IEnumerable ids, CancellationToken cancellationToken = default); void Purge(bool vacuum = false); + Task PurgeAsync(bool vacuum = false, CancellationToken cancellationToken = default); bool HasItems(); + Task HasItemsAsync(CancellationToken cancellationToken = default); TModel Single(); + Task SingleAsync(CancellationToken cancellationToken = default); TModel SingleOrDefault(); + Task SingleOrDefaultAsync(CancellationToken cancellationToken = default); PagingSpec GetPaged(PagingSpec pagingSpec); } @@ -95,10 +119,16 @@ public BasicRepository(IDatabase database, IEventAggregator eventAggregator) protected virtual List Query(SqlBuilder builder) => _database.Query(builder).ToList(); + protected virtual IAsyncEnumerable QueryAsync(SqlBuilder builder, CancellationToken cancellationToken = default) => _database.QueryAsync(builder, cancellationToken); + protected virtual List QueryDistinct(SqlBuilder builder) => _database.QueryDistinct(builder).ToList(); + protected virtual IAsyncEnumerable QueryDistinctAsync(SqlBuilder builder, CancellationToken cancellationToken = default) => _database.QueryDistinctAsync(builder, cancellationToken); + protected List Query(Expression> where) => Query(Builder().Where(where)); + protected IAsyncEnumerable QueryAsync(Expression> where, CancellationToken cancellationToken = default) => QueryAsync(Builder().Where(where), cancellationToken); + public int Count() { using (var conn = _database.OpenConnection()) @@ -107,11 +137,25 @@ public int Count() } } + public async Task CountAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + await using var conn = await _database.OpenConnectionAsync(cancellationToken); + + return await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{_table}\""); + } + public virtual IEnumerable All() { return Query(Builder()); } + public virtual IAsyncEnumerable AllAsync(CancellationToken cancellationToken = default) + { + return QueryAsync(Builder(), cancellationToken); + } + public TModel Find(int id) { var model = Query(x => x.Id == id).FirstOrDefault(); @@ -119,6 +163,11 @@ public TModel Find(int id) return model; } + public async Task FindAsync(int id, CancellationToken cancellationToken = default) + { + return await QueryAsync(x => x.Id == id, cancellationToken).FirstOrDefaultAsync(cancellationToken); + } + public TModel Get(int id) { var model = Find(id); @@ -131,6 +180,18 @@ public TModel Get(int id) return model; } + public async Task GetAsync(int id, CancellationToken cancellationToken = default) + { + var model = await FindAsync(id, cancellationToken); + + if (model == null) + { + throw new ModelNotFoundException(typeof(TModel), id); + } + + return model; + } + public IEnumerable Get(IEnumerable ids) { if (!ids.Any()) @@ -148,16 +209,47 @@ public IEnumerable Get(IEnumerable ids) return result; } + public async IAsyncEnumerable GetAsync(IEnumerable ids, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (!ids.Any()) + { + yield break; + } + + var result = QueryAsync(x => ids.Contains(x.Id), cancellationToken); + var countResult = await result.CountAsync(cancellationToken); + + if (countResult != ids.Count()) + { + throw new ApplicationException($"Expected query to return {ids.Count()} rows but returned {countResult}"); + } + + await foreach (var model in result) + { + yield return model; + } + } + public TModel SingleOrDefault() { return All().SingleOrDefault(); } + public async Task SingleOrDefaultAsync(CancellationToken cancellationToken = default) + { + return await AllAsync(cancellationToken).SingleOrDefaultAsync(cancellationToken); + } + public TModel Single() { return All().Single(); } + public async Task SingleAsync(CancellationToken cancellationToken = default) + { + return await AllAsync(cancellationToken).SingleAsync(cancellationToken); + } + public TModel Insert(TModel model) { if (model.Id != 0) @@ -175,6 +267,23 @@ public TModel Insert(TModel model) return model; } + public async Task InsertAsync(TModel model, CancellationToken cancellationToken = default) + { + if (model.Id != 0) + { + throw new InvalidOperationException("Can't insert model with existing ID " + model.Id); + } + + await using (var conn = await _database.OpenConnectionAsync(cancellationToken)) + { + model = await InsertAsync(conn, null, model, cancellationToken); + } + + ModelCreated(model); + + return model; + } + private string GetInsertSql() { var sbColumnList = new StringBuilder(null); @@ -220,6 +329,19 @@ private TModel Insert(IDbConnection connection, IDbTransaction transaction, TMod return model; } + private async Task InsertAsync(IDbConnection connection, IDbTransaction transaction, TModel model, CancellationToken cancellationToken = default) + { + SqlBuilderExtensions.LogQuery(_insertSql, model); + + var multi = await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.QueryMultipleAsync(state._insertSql, state.model, state.transaction), (connection, _insertSql, model, transaction), cancellationToken); + + var multiRead = await multi.ReadAsync(); + var id = (int)(multiRead.First().id ?? multiRead.First().Id); + _keyProperty.SetValue(model, id); + + return model; + } + public void InsertMany(IList models) { if (models.Any(x => x.Id != 0)) @@ -241,6 +363,24 @@ public void InsertMany(IList models) } } + public async Task InsertManyAsync(IList models, CancellationToken cancellationToken = default) + { + if (models.Any(x => x.Id != 0)) + { + throw new InvalidOperationException("Can't insert model with existing ID != 0"); + } + + await using var conn = await _database.OpenConnectionAsync(cancellationToken); + await using var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted, cancellationToken); + + foreach (var model in models) + { + await InsertAsync(conn, tran, model, cancellationToken); + } + + await tran.CommitAsync(cancellationToken); + } + public TModel Update(TModel model) { if (model.Id == 0) @@ -258,6 +398,23 @@ public TModel Update(TModel model) return model; } + public async Task UpdateAsync(TModel model, CancellationToken cancellationToken = default) + { + if (model.Id == 0) + { + throw new InvalidOperationException("Can't update model with ID 0"); + } + + await using (var conn = await _database.OpenConnectionAsync(cancellationToken)) + { + await UpdateFieldsAsync(conn, null, model, _properties, cancellationToken); + } + + ModelUpdated(model); + + return model; + } + public void UpdateMany(IList models) { if (models.Any(x => x.Id == 0)) @@ -273,11 +430,31 @@ public void UpdateMany(IList models) } } + public async Task UpdateManyAsync(IList models, CancellationToken cancellationToken = default) + { + if (models.Any(x => x.Id == 0)) + { + throw new InvalidOperationException("Can't update model with ID 0"); + } + + await using var conn = await _database.OpenConnectionAsync(cancellationToken); + await using var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted, cancellationToken); + + await UpdateFieldsAsync(conn, tran, models, _properties, cancellationToken); + + await tran.CommitAsync(cancellationToken); + } + protected void Delete(Expression> where) { Delete(Builder().Where(where)); } + protected async Task DeleteAsync(Expression> where, CancellationToken cancellationToken = default) + { + await DeleteAsync(Builder().Where(where), cancellationToken); + } + protected void Delete(SqlBuilder builder) { var sql = builder.AddDeleteTemplate(typeof(TModel)); @@ -288,16 +465,37 @@ protected void Delete(SqlBuilder builder) } } + protected async Task DeleteAsync(SqlBuilder builder, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var sql = builder.AddDeleteTemplate(typeof(TModel)); + + await using var conn = await _database.OpenConnectionAsync(cancellationToken); + + await conn.ExecuteAsync(sql.RawSql, sql.Parameters); + } + public void Delete(TModel model) { Delete(model.Id); } + public async Task DeleteAsync(TModel model, CancellationToken cancellationToken = default) + { + await DeleteAsync(model.Id, cancellationToken); + } + public void Delete(int id) { Delete(x => x.Id == id); } + public async Task DeleteAsync(int id, CancellationToken cancellationToken = default) + { + await DeleteAsync(x => x.Id == id, cancellationToken); + } + public void DeleteMany(IEnumerable ids) { if (ids.Any()) @@ -306,11 +504,24 @@ public void DeleteMany(IEnumerable ids) } } + public async Task DeleteManyAsync(IEnumerable ids, CancellationToken cancellationToken = default) + { + if (ids.Any()) + { + await DeleteAsync(x => ids.Contains(x.Id), cancellationToken); + } + } + public void DeleteMany(List models) { DeleteMany(models.Select(m => m.Id)); } + public async Task DeleteManyAsync(List models, CancellationToken cancellationToken = default) + { + await DeleteManyAsync(models.Select(m => m.Id), cancellationToken); + } + public TModel Upsert(TModel model) { if (model.Id == 0) @@ -323,6 +534,20 @@ public TModel Upsert(TModel model) return model; } + public async Task UpsertAsync(TModel model, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (model.Id == 0) + { + await InsertAsync(model, cancellationToken); + return model; + } + + await UpdateAsync(model, cancellationToken); + return model; + } + public void Purge(bool vacuum = false) { using (var conn = _database.OpenConnection()) @@ -336,6 +561,21 @@ public void Purge(bool vacuum = false) } } + public async Task PurgeAsync(bool vacuum = false, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + await using (var conn = await _database.OpenConnectionAsync(cancellationToken)) + { + await conn.ExecuteAsync($"DELETE FROM \"{_table}\""); + } + + if (vacuum) + { + Vacuum(); + } + } + protected void Vacuum() { _database.Vacuum(); @@ -346,6 +586,11 @@ public bool HasItems() return Count() > 0; } + public async Task HasItemsAsync(CancellationToken cancellationToken = default) + { + return await CountAsync(cancellationToken) > 0; + } + public void SetFields(TModel model, params Expression>[] properties) { if (model.Id == 0) @@ -363,6 +608,23 @@ public void SetFields(TModel model, params Expression>[] pr ModelUpdated(model); } + public async Task SetFieldsAsync(TModel model, params Expression>[] properties) + { + if (model.Id == 0) + { + throw new InvalidOperationException("Attempted to update model without ID"); + } + + var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList(); + + await using (var conn = await _database.OpenConnectionAsync()) + { + await UpdateFieldsAsync(conn, null, model, propertiesToUpdate); + } + + ModelUpdated(model); + } + public void SetFields(IList models, params Expression>[] properties) { if (models.Any(x => x.Id == 0)) @@ -385,6 +647,28 @@ public void SetFields(IList models, params Expression models, params Expression>[] properties) + { + if (models.Any(x => x.Id == 0)) + { + throw new InvalidOperationException("Attempted to update model without ID"); + } + + var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList(); + + await using (var conn = await _database.OpenConnectionAsync()) + await using (var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted)) + { + await UpdateFieldsAsync(conn, tran, models, propertiesToUpdate); + await tran.CommitAsync(); + } + + foreach (var model in models) + { + ModelUpdated(model); + } + } + private string GetUpdateSql(List propertiesToUpdate) { var sb = new StringBuilder(); @@ -414,6 +698,17 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction, RetryStrategy.Execute(static (state, _) => state.connection.Execute(state.sql, state.model, transaction: state.transaction), (connection, sql, model, transaction)); } + private async Task UpdateFieldsAsync(IDbConnection connection, IDbTransaction transaction, TModel model, List propertiesToUpdate, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate); + + SqlBuilderExtensions.LogQuery(sql, model); + + await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.ExecuteAsync(state.sql, state.model, transaction: state.transaction), (connection, sql, model, transaction), cancellationToken); + } + private void UpdateFields(IDbConnection connection, IDbTransaction transaction, IList models, List propertiesToUpdate) { var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate); @@ -426,6 +721,20 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction, RetryStrategy.Execute(static (state, _) => state.connection.Execute(state.sql, state.models, transaction: state.transaction), (connection, sql, models, transaction)); } + private async Task UpdateFieldsAsync(IDbConnection connection, IDbTransaction transaction, IList models, List propertiesToUpdate, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate); + + foreach (var model in models) + { + SqlBuilderExtensions.LogQuery(sql, model); + } + + await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.ExecuteAsync(state.sql, state.models, transaction: state.transaction), (connection, sql, models, transaction), cancellationToken); + } + protected virtual SqlBuilder PagedBuilder() => Builder(); protected virtual IEnumerable PagedQuery(SqlBuilder sql) => Query(sql); diff --git a/src/NzbDrone.Core/Datastore/Database.cs b/src/NzbDrone.Core/Datastore/Database.cs index 02e687552..107fa8ee9 100644 --- a/src/NzbDrone.Core/Datastore/Database.cs +++ b/src/NzbDrone.Core/Datastore/Database.cs @@ -1,6 +1,8 @@ using System; using System.Data.Common; using System.Data.SQLite; +using System.Threading; +using System.Threading.Tasks; using Dapper; using NLog; using NzbDrone.Common.Instrumentation; @@ -10,23 +12,27 @@ namespace NzbDrone.Core.Datastore public interface IDatabase { DbConnection OpenConnection(); + Task OpenConnectionAsync(CancellationToken cancellationToken = default); Version Version { get; } int Migration { get; } DatabaseType DatabaseType { get; } void Vacuum(); + Task VacuumAsync(CancellationToken cancellationToken = default); } public class Database : IDatabase { private readonly string _databaseName; private readonly Func _datamapperFactory; + private readonly Func> _datamapperFactoryAsync; private readonly Logger _logger = NzbDroneLogger.GetLogger(typeof(Database)); - public Database(string databaseName, Func datamapperFactory) + public Database(string databaseName, Func datamapperFactory, Func> datamapperFactoryAsync) { _databaseName = databaseName; _datamapperFactory = datamapperFactory; + _datamapperFactoryAsync = datamapperFactoryAsync; } public DbConnection OpenConnection() @@ -34,6 +40,11 @@ public DbConnection OpenConnection() return _datamapperFactory(); } + public Task OpenConnectionAsync(CancellationToken cancellationToken = default) + { + return _datamapperFactoryAsync(cancellationToken); + } + public DatabaseType DatabaseType { get @@ -83,6 +94,24 @@ public void Vacuum() _logger.Error(e, "An Error occurred while vacuuming database."); } } + + public async Task VacuumAsync(CancellationToken cancellationToken = default) + { + try + { + _logger.Info("Vacuuming {0} database", _databaseName); + await using (var db = await _datamapperFactoryAsync(cancellationToken)) + { + await db.ExecuteAsync("Vacuum;"); + } + + _logger.Info("{0} database compressed", _databaseName); + } + catch (Exception e) + { + _logger.Error(e, "An Error occurred while vacuuming database."); + } + } } public enum DatabaseType diff --git a/src/NzbDrone.Core/Datastore/DbFactory.cs b/src/NzbDrone.Core/Datastore/DbFactory.cs index 7901aa27b..481c5c5af 100644 --- a/src/NzbDrone.Core/Datastore/DbFactory.cs +++ b/src/NzbDrone.Core/Datastore/DbFactory.cs @@ -87,23 +87,42 @@ public IDatabase Create(MigrationContext migrationContext) } } - var db = new Database(migrationContext.MigrationType.ToString(), () => - { - DbConnection conn; - - if (connectionInfo.DatabaseType == DatabaseType.SQLite) + var db = new Database( + migrationContext.MigrationType.ToString(), + () => { - conn = SQLiteFactory.Instance.CreateConnection(); - conn.ConnectionString = connectionInfo.ConnectionString; - } - else - { - conn = new NpgsqlConnection(connectionInfo.ConnectionString); - } + DbConnection conn; - conn.Open(); - return conn; - }); + if (connectionInfo.DatabaseType == DatabaseType.SQLite) + { + conn = SQLiteFactory.Instance.CreateConnection(); + conn.ConnectionString = connectionInfo.ConnectionString; + } + else + { + conn = new NpgsqlConnection(connectionInfo.ConnectionString); + } + + conn.Open(); + return conn; + }, + async cancellationToken => + { + DbConnection conn; + + if (connectionInfo.DatabaseType == DatabaseType.SQLite) + { + conn = SQLiteFactory.Instance.CreateConnection(); + conn.ConnectionString = connectionInfo.ConnectionString; + } + else + { + conn = new NpgsqlConnection(connectionInfo.ConnectionString); + } + + await conn.OpenAsync(cancellationToken); + return conn; + }); return db; } diff --git a/src/NzbDrone.Core/Datastore/Extensions/SqlMapperAsyncExtensions.cs b/src/NzbDrone.Core/Datastore/Extensions/SqlMapperAsyncExtensions.cs new file mode 100644 index 000000000..9754cc085 --- /dev/null +++ b/src/NzbDrone.Core/Datastore/Extensions/SqlMapperAsyncExtensions.cs @@ -0,0 +1,204 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Dapper; + +namespace NzbDrone.Core.Datastore.Extensions +{ + public static class SqlMapperAsyncExtensions + { + public static async IAsyncEnumerable QueryAsync(this IDatabase db, string sql, object param = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await using var conn = await db.OpenConnectionAsync(cancellationToken); + + IAsyncEnumerable items; + try + { + items = conn.QueryUnbufferedAsync(sql, param); + } + catch (Exception e) + { + e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param)); + throw; + } + + if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(T), out var lazyProperties)) + { + await foreach (var item in items.WithCancellation(cancellationToken)) + { + ApplyLazyLoad(db, item, lazyProperties); + + yield return item; + } + } + } + + public static async Task> QueryAsync(this IDatabase db, string sql, Func map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default) + { + TReturn MapWithLazy(TFirst first, TSecond second) + { + ApplyLazyLoad(db, first); + ApplyLazyLoad(db, second); + return map(first, second); + } + + await using var conn = await db.OpenConnectionAsync(cancellationToken); + + try + { + return await conn.QueryAsync(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType); + } + catch (Exception e) + { + e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param)); + throw; + } + } + + public static async Task> QueryAsync(this IDatabase db, string sql, Func map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default) + { + TReturn MapWithLazy(TFirst first, TSecond second, TThird third) + { + ApplyLazyLoad(db, first); + ApplyLazyLoad(db, second); + ApplyLazyLoad(db, third); + return map(first, second, third); + } + + await using var conn = await db.OpenConnectionAsync(cancellationToken); + + try + { + return await conn.QueryAsync(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType); + } + catch (Exception e) + { + e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param)); + throw; + } + } + + public static async Task> QueryAsync(this IDatabase db, string sql, Func map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default) + { + TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth) + { + ApplyLazyLoad(db, first); + ApplyLazyLoad(db, second); + ApplyLazyLoad(db, third); + ApplyLazyLoad(db, fourth); + return map(first, second, third, fourth); + } + + await using var conn = await db.OpenConnectionAsync(cancellationToken); + + try + { + return await conn.QueryAsync(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType); + } + catch (Exception e) + { + e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param)); + throw; + } + } + + public static async Task> QueryAsync(this IDatabase db, string sql, Func map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default) + { + TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth, TFifth fifth) + { + ApplyLazyLoad(db, first); + ApplyLazyLoad(db, second); + ApplyLazyLoad(db, third); + ApplyLazyLoad(db, fourth); + ApplyLazyLoad(db, fifth); + return map(first, second, third, fourth, fifth); + } + + await using var conn = await db.OpenConnectionAsync(cancellationToken); + + try + { + return await conn.QueryAsync(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType); + } + catch (Exception e) + { + e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param)); + throw; + } + } + + public static IAsyncEnumerable QueryAsync(this IDatabase db, SqlBuilder builder, CancellationToken cancellationToken = default) + { + var type = typeof(T); + var sql = builder.Select(type).AddSelectTemplate(type); + + return db.QueryAsync(sql.RawSql, sql.Parameters, cancellationToken: cancellationToken); + } + + public static IAsyncEnumerable QueryDistinctAsync(this IDatabase db, SqlBuilder builder, CancellationToken cancellationToken = default) + { + var type = typeof(T); + var sql = builder.SelectDistinct(type).AddSelectTemplate(type); + + return db.QueryAsync(sql.RawSql, sql.Parameters, cancellationToken: cancellationToken); + } + + public static async Task> QueryJoinedAsync(this IDatabase db, SqlBuilder builder, Func mapper) + { + var type = typeof(T); + var sql = builder.Select(type, typeof(T2)).AddSelectTemplate(type); + + return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters); + } + + public static async Task> QueryJoinedAsync(this IDatabase db, SqlBuilder builder, Func mapper) + { + var type = typeof(T); + var sql = builder.Select(type, typeof(T2), typeof(T3)).AddSelectTemplate(type); + + return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters); + } + + public static async Task> QueryJoinedAsync(this IDatabase db, SqlBuilder builder, Func mapper) + { + var type = typeof(T); + var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4)).AddSelectTemplate(type); + + return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters); + } + + public static async Task> QueryJoinedAsync(this IDatabase db, SqlBuilder builder, Func mapper) + { + var type = typeof(T); + var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4), typeof(T5)).AddSelectTemplate(type); + + return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters); + } + + private static void ApplyLazyLoad(IDatabase db, TModel model) + { + if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(TModel), out var lazyProperties)) + { + ApplyLazyLoad(db, model, lazyProperties); + } + } + + private static void ApplyLazyLoad(IDatabase db, TModel model, List lazyProperties) + { + if (model == null) + { + return; + } + + foreach (var lazyProperty in lazyProperties) + { + var lazy = (ILazyLoaded)lazyProperty.LazyLoad.Clone(); + lazy.Prepare(db, model); + lazyProperty.Property.SetValue(model, lazy); + } + } + } +} diff --git a/src/NzbDrone.Core/Datastore/LogDatabase.cs b/src/NzbDrone.Core/Datastore/LogDatabase.cs index f996986d8..26e045cce 100644 --- a/src/NzbDrone.Core/Datastore/LogDatabase.cs +++ b/src/NzbDrone.Core/Datastore/LogDatabase.cs @@ -1,5 +1,7 @@ using System; using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; namespace NzbDrone.Core.Datastore { @@ -23,6 +25,11 @@ public DbConnection OpenConnection() return _database.OpenConnection(); } + public async Task OpenConnectionAsync(CancellationToken cancellationToken = default) + { + return await _database.OpenConnectionAsync(cancellationToken); + } + public Version Version => _database.Version; public int Migration => _database.Migration; @@ -33,5 +40,10 @@ public void Vacuum() { _database.Vacuum(); } + + public async Task VacuumAsync(CancellationToken cancellationToken = default) + { + await _database.VacuumAsync(cancellationToken); + } } } diff --git a/src/NzbDrone.Core/Datastore/MainDatabase.cs b/src/NzbDrone.Core/Datastore/MainDatabase.cs index 7e39b1356..5b34bf8e3 100644 --- a/src/NzbDrone.Core/Datastore/MainDatabase.cs +++ b/src/NzbDrone.Core/Datastore/MainDatabase.cs @@ -1,5 +1,7 @@ using System; using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; using StackExchange.Profiling; using StackExchange.Profiling.Data; @@ -32,6 +34,18 @@ public DbConnection OpenConnection() return new ProfiledDbConnection(connection, MiniProfiler.Current); } + public async Task OpenConnectionAsync(CancellationToken cancellationToken = default) + { + var connection = await _database.OpenConnectionAsync(cancellationToken); + + if (_databaseType == DatabaseType.PostgreSQL) + { + return new ProfiledImplementations.NpgSqlConnection(connection, MiniProfiler.Current); + } + + return new ProfiledDbConnection(connection, MiniProfiler.Current); + } + public Version Version => _database.Version; public int Migration => _database.Migration; @@ -42,5 +56,10 @@ public void Vacuum() { _database.Vacuum(); } + + public async Task VacuumAsync(CancellationToken cancellationToken = default) + { + await _database.VacuumAsync(cancellationToken); + } } }