Add async overloaded methods for BasicRepository

This commit is contained in:
Bogdan 2026-04-27 11:57:18 +03:00
parent d0504c2790
commit e965f22518
7 changed files with 705 additions and 112 deletions

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FizzWare.NBuilder;
using FluentAssertions;
using NUnit.Framework;
@ -33,277 +34,277 @@ public void Setup()
}
[Test]
public void should_be_able_to_insert()
public async Task should_be_able_to_insert()
{
Subject.Insert(_basicList[0]);
Subject.All().Should().HaveCount(1);
await Subject.InsertAsync(_basicList[0]);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(1);
}
[Test]
public void should_be_able_to_insert_many()
public async Task should_be_able_to_insert_many()
{
Subject.InsertMany(_basicList);
Subject.All().Should().HaveCount(5);
await Subject.InsertManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(5);
}
[Test]
public void insert_many_should_throw_if_id_not_zero()
{
_basicList[1].Id = 999;
Assert.Throws<InvalidOperationException>(() => Subject.InsertMany(_basicList));
Assert.ThrowsAsync<InvalidOperationException>(() => Subject.InsertManyAsync(_basicList));
}
[Test]
public void should_be_able_to_get_count()
public async Task should_be_able_to_get_count()
{
Subject.InsertMany(_basicList);
Subject.Count().Should().Be(_basicList.Count);
await Subject.InsertManyAsync(_basicList);
(await Subject.CountAsync()).Should().Be(_basicList.Count);
}
[Test]
public void should_be_able_to_find_by_id()
public async Task should_be_able_to_find_by_id()
{
Subject.InsertMany(_basicList);
var storeObject = Subject.Get(_basicList[1].Id);
await Subject.InsertManyAsync(_basicList);
var storeObject = await Subject.GetAsync(_basicList[1].Id);
storeObject.Should().BeEquivalentTo(_basicList[1], o => o.IncludingAllRuntimeProperties());
}
[Test]
public void should_be_able_to_update()
public async Task should_be_able_to_update()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var item = _basicList[1];
item.Interval = 999;
Subject.Update(item);
await Subject.UpdateAsync(item);
Subject.All().Should().BeEquivalentTo(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList);
}
[Test]
public void should_be_able_to_upsert_new()
public async Task should_be_able_to_upsert_new()
{
Subject.Upsert(_basicList[0]);
Subject.All().Should().HaveCount(1);
await Subject.UpsertAsync(_basicList[0]);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(1);
}
[Test]
public void should_be_able_to_upsert_existing()
public async Task should_be_able_to_upsert_existing()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var item = _basicList[1];
item.Interval = 999;
Subject.Upsert(item);
await Subject.UpsertAsync(item);
Subject.All().Should().BeEquivalentTo(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList);
}
[Test]
public void should_be_able_to_update_single_field()
public async Task should_be_able_to_update_single_field()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var item = _basicList[1];
var executionBackup = item.LastExecution;
item.Interval = 999;
item.LastExecution = DateTime.UtcNow;
Subject.SetFields(item, x => x.Interval);
await Subject.SetFieldsAsync(item, x => x.Interval);
item.LastExecution = executionBackup;
Subject.All().Should().BeEquivalentTo(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList);
}
[Test]
public void set_fields_should_throw_if_id_zero()
public async Task set_fields_should_throw_if_id_zero()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
_basicList[1].Id = 0;
_basicList[1].LastExecution = DateTime.UtcNow;
Assert.Throws<InvalidOperationException>(() => Subject.SetFields(_basicList[1], x => x.Interval));
Assert.ThrowsAsync<InvalidOperationException>(() => Subject.SetFieldsAsync(_basicList[1], x => x.Interval));
}
[Test]
public void should_be_able_to_delete_model_by_id()
public async Task should_be_able_to_delete_model_by_id()
{
Subject.InsertMany(_basicList);
Subject.All().Should().HaveCount(5);
await Subject.InsertManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(5);
Subject.Delete(_basicList[0].Id);
Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id));
await Subject.DeleteAsync(_basicList[0].Id);
(await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id));
}
[Test]
public void should_be_able_to_delete_model_by_object()
public async Task should_be_able_to_delete_model_by_object()
{
Subject.InsertMany(_basicList);
Subject.All().Should().HaveCount(5);
await Subject.InsertManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(5);
Subject.Delete(_basicList[0]);
Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id));
await Subject.DeleteAsync(_basicList[0]);
(await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(1).Select(x => x.Id));
}
[Test]
public void get_many_should_return_empty_list_if_no_ids()
public async Task get_many_should_return_empty_list_if_no_ids()
{
Subject.Get(new List<int>()).Should().BeEquivalentTo(new List<ScheduledTask>());
(await Subject.GetAsync(new List<int>()).ToListAsync()).Should().BeEquivalentTo(new List<ScheduledTask>());
}
[Test]
public void get_many_should_throw_if_not_all_found()
public async Task get_many_should_throw_if_not_all_found()
{
Subject.InsertMany(_basicList);
Assert.Throws<ApplicationException>(() => Subject.Get(new[] { 999 }));
await Subject.InsertManyAsync(_basicList);
Assert.ThrowsAsync<ApplicationException>(async () => await Subject.GetAsync([999]).ToListAsync());
}
[Test]
public void should_be_able_to_find_by_multiple_id()
public async Task should_be_able_to_find_by_multiple_id()
{
Subject.InsertMany(_basicList);
var storeObject = Subject.Get(_basicList.Take(2).Select(x => x.Id));
await Subject.InsertManyAsync(_basicList);
var storeObject = await Subject.GetAsync(_basicList.Take(2).Select(x => x.Id)).ToListAsync();
storeObject.Select(x => x.Id).Should().BeEquivalentTo(_basicList.Take(2).Select(x => x.Id));
}
[Test]
public void should_be_able_to_update_many()
public async Task should_be_able_to_update_many()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
_basicList.ForEach(x => x.Interval = 999);
Subject.UpdateMany(_basicList);
Subject.All().Should().BeEquivalentTo(_basicList);
await Subject.UpdateManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList);
}
[Test]
public void update_many_should_throw_if_id_zero()
public async Task update_many_should_throw_if_id_zero()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
_basicList[1].Id = 0;
Assert.Throws<InvalidOperationException>(() => Subject.UpdateMany(_basicList));
Assert.ThrowsAsync<InvalidOperationException>(() => Subject.UpdateManyAsync(_basicList));
}
[Test]
public void should_be_able_to_update_many_single_field()
public async Task should_be_able_to_update_many_single_field()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var executionBackup = _basicList.Select(x => x.LastExecution).ToList();
_basicList.ForEach(x => x.Interval = 999);
_basicList.ForEach(x => x.LastExecution = DateTime.UtcNow);
Subject.SetFields(_basicList, x => x.Interval);
await Subject.SetFieldsAsync(_basicList, x => x.Interval);
for (var i = 0; i < _basicList.Count; i++)
{
_basicList[i].LastExecution = executionBackup[i];
}
Subject.All().Should().BeEquivalentTo(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().BeEquivalentTo(_basicList);
}
[Test]
public void set_fields_should_throw_if_any_id_zero()
public async Task set_fields_should_throw_if_any_id_zero()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
_basicList.ForEach(x => x.Interval = 999);
_basicList[1].Id = 0;
Assert.Throws<InvalidOperationException>(() => Subject.SetFields(_basicList, x => x.Interval));
Assert.ThrowsAsync<InvalidOperationException>(() => Subject.SetFieldsAsync(_basicList, x => x.Interval));
}
[Test]
public void should_be_able_to_delete_many_by_model()
public async Task should_be_able_to_delete_many_by_model()
{
Subject.InsertMany(_basicList);
Subject.All().Should().HaveCount(5);
await Subject.InsertManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(5);
Subject.DeleteMany(_basicList.Take(2).ToList());
Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id));
await Subject.DeleteManyAsync(_basicList.Take(2).ToList());
(await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id));
}
[Test]
public void should_be_able_to_delete_many_by_id()
public async Task should_be_able_to_delete_many_by_id()
{
Subject.InsertMany(_basicList);
Subject.All().Should().HaveCount(5);
await Subject.InsertManyAsync(_basicList);
(await Subject.AllAsync().ToListAsync()).Should().HaveCount(5);
Subject.DeleteMany(_basicList.Take(2).Select(x => x.Id).ToList());
Subject.All().Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id));
await Subject.DeleteManyAsync(_basicList.Take(2).Select(x => x.Id).ToList());
(await Subject.AllAsync().ToListAsync()).Select(x => x.Id).Should().BeEquivalentTo(_basicList.Skip(2).Select(x => x.Id));
}
[Test]
public void purge_should_delete_all()
public async Task purge_should_delete_all()
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
AllStoredModels.Should().HaveCount(5);
Subject.Purge();
await Subject.PurgeAsync();
AllStoredModels.Should().BeEmpty();
}
[Test]
public void has_items_should_return_false_with_no_items()
public async Task has_items_should_return_false_with_no_items()
{
Subject.HasItems().Should().BeFalse();
(await Subject.HasItemsAsync()).Should().BeFalse();
}
[Test]
public void has_items_should_return_true_with_items()
public async Task has_items_should_return_true_with_items()
{
Subject.InsertMany(_basicList);
Subject.HasItems().Should().BeTrue();
await Subject.InsertManyAsync(_basicList);
(await Subject.HasItemsAsync()).Should().BeTrue();
}
[Test]
public void single_should_throw_on_empty()
{
Assert.Throws<InvalidOperationException>(() => Subject.Single());
Assert.ThrowsAsync<InvalidOperationException>(() => Subject.SingleAsync());
}
[Test]
public void should_be_able_to_get_single()
public async Task should_be_able_to_get_single()
{
Subject.Insert(_basicList[0]);
Subject.Single().Should().BeEquivalentTo(_basicList[0]);
await Subject.InsertAsync(_basicList[0]);
(await Subject.SingleAsync()).Should().BeEquivalentTo(_basicList[0]);
}
[Test]
public void single_or_default_on_empty_table_should_return_null()
public async Task single_or_default_on_empty_table_should_return_null()
{
Subject.SingleOrDefault().Should().BeNull();
(await Subject.SingleOrDefaultAsync()).Should().BeNull();
}
[Test]
public void getting_model_with_invalid_id_should_throw()
{
Assert.Throws<ModelNotFoundException>(() => Subject.Get(12));
Assert.ThrowsAsync<ModelNotFoundException>(() => Subject.GetAsync(12));
}
[Test]
public void get_all_with_empty_db_should_return_empty_list()
public async Task get_all_with_empty_db_should_return_empty_list()
{
Subject.All().Should().BeEmpty();
(await Subject.AllAsync().ToListAsync()).Should().BeEmpty();
}
[Test]
public void should_be_able_to_call_ToList_on_empty_queryable()
public async Task should_be_able_to_call_ToList_on_empty_queryable()
{
Subject.All().ToList().Should().BeEmpty();
(await Subject.AllAsync().ToListAsync()).Should().BeEmpty();
}
[TestCase(1, 2)]
[TestCase(2, 2)]
[TestCase(3, 1)]
public void get_paged_should_work(int page, int count)
public async Task get_paged_should_work(int page, int count)
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var data = Subject.GetPaged(new PagingSpec<ScheduledTask>() { Page = page, PageSize = 2, SortKey = "LastExecution", SortDirection = SortDirection.Descending });
data.Page.Should().Be(page);
@ -315,9 +316,9 @@ public void get_paged_should_work(int page, int count)
[TestCase(1, 2)]
[TestCase(2, 2)]
[TestCase(3, 1)]
public void get_paged_should_work_with_null_sort_key(int page, int count)
public async Task get_paged_should_work_with_null_sort_key(int page, int count)
{
Subject.InsertMany(_basicList);
await Subject.InsertManyAsync(_basicList);
var data = Subject.GetPaged(new PagingSpec<ScheduledTask>() { Page = page, PageSize = 2, SortDirection = SortDirection.Descending });
data.Page.Should().Be(page);

View file

@ -5,11 +5,15 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Dapper;
using NLog;
using NzbDrone.Common.Instrumentation;
using NzbDrone.Core.Datastore.Events;
using NzbDrone.Core.Datastore.Extensions;
using NzbDrone.Core.Messaging.Events;
using Polly;
using Polly.Retry;
@ -20,25 +24,45 @@ public interface IBasicRepository<TModel>
where TModel : ModelBase, new()
{
IEnumerable<TModel> All();
IAsyncEnumerable<TModel> AllAsync(CancellationToken cancellationToken = default);
int Count();
Task<int> CountAsync(CancellationToken cancellationToken = default);
TModel Find(int id);
Task<TModel> FindAsync(int id, CancellationToken cancellationToken = default);
TModel Get(int id);
Task<TModel> GetAsync(int id, CancellationToken cancellationToken = default);
TModel Insert(TModel model);
Task<TModel> InsertAsync(TModel model, CancellationToken cancellationToken = default);
TModel Update(TModel model);
Task<TModel> UpdateAsync(TModel model, CancellationToken cancellationToken = default);
TModel Upsert(TModel model);
Task<TModel> UpsertAsync(TModel model, CancellationToken cancellationToken = default);
void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties);
Task SetFieldsAsync(TModel model, params Expression<Func<TModel, object>>[] properties);
void Delete(TModel model);
Task DeleteAsync(TModel model, CancellationToken cancellationToken = default);
void Delete(int id);
Task DeleteAsync(int id, CancellationToken cancellationToken = default);
IEnumerable<TModel> Get(IEnumerable<int> ids);
IAsyncEnumerable<TModel> GetAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default);
void InsertMany(IList<TModel> model);
Task InsertManyAsync(IList<TModel> models, CancellationToken cancellationToken = default);
void UpdateMany(IList<TModel> model);
Task UpdateManyAsync(IList<TModel> models, CancellationToken cancellationToken = default);
void SetFields(IList<TModel> models, params Expression<Func<TModel, object>>[] properties);
Task SetFieldsAsync(IList<TModel> models, params Expression<Func<TModel, object>>[] properties);
void DeleteMany(List<TModel> model);
Task DeleteManyAsync(List<TModel> models, CancellationToken cancellationToken = default);
void DeleteMany(IEnumerable<int> ids);
Task DeleteManyAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default);
void Purge(bool vacuum = false);
Task PurgeAsync(bool vacuum = false, CancellationToken cancellationToken = default);
bool HasItems();
Task<bool> HasItemsAsync(CancellationToken cancellationToken = default);
TModel Single();
Task<TModel> SingleAsync(CancellationToken cancellationToken = default);
TModel SingleOrDefault();
Task<TModel> SingleOrDefaultAsync(CancellationToken cancellationToken = default);
PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec);
}
@ -95,10 +119,16 @@ public BasicRepository(IDatabase database, IEventAggregator eventAggregator)
protected virtual List<TModel> Query(SqlBuilder builder) => _database.Query<TModel>(builder).ToList();
protected virtual IAsyncEnumerable<TModel> QueryAsync(SqlBuilder builder, CancellationToken cancellationToken = default) => _database.QueryAsync<TModel>(builder, cancellationToken);
protected virtual List<TModel> QueryDistinct(SqlBuilder builder) => _database.QueryDistinct<TModel>(builder).ToList();
protected virtual IAsyncEnumerable<TModel> QueryDistinctAsync(SqlBuilder builder, CancellationToken cancellationToken = default) => _database.QueryDistinctAsync<TModel>(builder, cancellationToken);
protected List<TModel> Query(Expression<Func<TModel, bool>> where) => Query(Builder().Where(where));
protected IAsyncEnumerable<TModel> QueryAsync(Expression<Func<TModel, bool>> where, CancellationToken cancellationToken = default) => QueryAsync(Builder().Where(where), cancellationToken);
public int Count()
{
using (var conn = _database.OpenConnection())
@ -107,11 +137,25 @@ public int Count()
}
}
public async Task<int> CountAsync(CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
await using var conn = await _database.OpenConnectionAsync(cancellationToken);
return await conn.ExecuteScalarAsync<int>($"SELECT COUNT(*) FROM \"{_table}\"");
}
public virtual IEnumerable<TModel> All()
{
return Query(Builder());
}
public virtual IAsyncEnumerable<TModel> AllAsync(CancellationToken cancellationToken = default)
{
return QueryAsync(Builder(), cancellationToken);
}
public TModel Find(int id)
{
var model = Query(x => x.Id == id).FirstOrDefault();
@ -119,6 +163,11 @@ public TModel Find(int id)
return model;
}
public async Task<TModel> FindAsync(int id, CancellationToken cancellationToken = default)
{
return await QueryAsync(x => x.Id == id, cancellationToken).FirstOrDefaultAsync(cancellationToken);
}
public TModel Get(int id)
{
var model = Find(id);
@ -131,6 +180,18 @@ public TModel Get(int id)
return model;
}
public async Task<TModel> GetAsync(int id, CancellationToken cancellationToken = default)
{
var model = await FindAsync(id, cancellationToken);
if (model == null)
{
throw new ModelNotFoundException(typeof(TModel), id);
}
return model;
}
public IEnumerable<TModel> Get(IEnumerable<int> ids)
{
if (!ids.Any())
@ -148,16 +209,47 @@ public IEnumerable<TModel> Get(IEnumerable<int> ids)
return result;
}
public async IAsyncEnumerable<TModel> GetAsync(IEnumerable<int> ids, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (!ids.Any())
{
yield break;
}
var result = QueryAsync(x => ids.Contains(x.Id), cancellationToken);
var countResult = await result.CountAsync(cancellationToken);
if (countResult != ids.Count())
{
throw new ApplicationException($"Expected query to return {ids.Count()} rows but returned {countResult}");
}
await foreach (var model in result)
{
yield return model;
}
}
public TModel SingleOrDefault()
{
return All().SingleOrDefault();
}
public async Task<TModel> SingleOrDefaultAsync(CancellationToken cancellationToken = default)
{
return await AllAsync(cancellationToken).SingleOrDefaultAsync(cancellationToken);
}
public TModel Single()
{
return All().Single();
}
public async Task<TModel> SingleAsync(CancellationToken cancellationToken = default)
{
return await AllAsync(cancellationToken).SingleAsync(cancellationToken);
}
public TModel Insert(TModel model)
{
if (model.Id != 0)
@ -175,6 +267,23 @@ public TModel Insert(TModel model)
return model;
}
public async Task<TModel> InsertAsync(TModel model, CancellationToken cancellationToken = default)
{
if (model.Id != 0)
{
throw new InvalidOperationException("Can't insert model with existing ID " + model.Id);
}
await using (var conn = await _database.OpenConnectionAsync(cancellationToken))
{
model = await InsertAsync(conn, null, model, cancellationToken);
}
ModelCreated(model);
return model;
}
private string GetInsertSql()
{
var sbColumnList = new StringBuilder(null);
@ -220,6 +329,19 @@ private TModel Insert(IDbConnection connection, IDbTransaction transaction, TMod
return model;
}
private async Task<TModel> InsertAsync(IDbConnection connection, IDbTransaction transaction, TModel model, CancellationToken cancellationToken = default)
{
SqlBuilderExtensions.LogQuery(_insertSql, model);
var multi = await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.QueryMultipleAsync(state._insertSql, state.model, state.transaction), (connection, _insertSql, model, transaction), cancellationToken);
var multiRead = await multi.ReadAsync();
var id = (int)(multiRead.First().id ?? multiRead.First().Id);
_keyProperty.SetValue(model, id);
return model;
}
public void InsertMany(IList<TModel> models)
{
if (models.Any(x => x.Id != 0))
@ -241,6 +363,24 @@ public void InsertMany(IList<TModel> models)
}
}
public async Task InsertManyAsync(IList<TModel> models, CancellationToken cancellationToken = default)
{
if (models.Any(x => x.Id != 0))
{
throw new InvalidOperationException("Can't insert model with existing ID != 0");
}
await using var conn = await _database.OpenConnectionAsync(cancellationToken);
await using var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted, cancellationToken);
foreach (var model in models)
{
await InsertAsync(conn, tran, model, cancellationToken);
}
await tran.CommitAsync(cancellationToken);
}
public TModel Update(TModel model)
{
if (model.Id == 0)
@ -258,6 +398,23 @@ public TModel Update(TModel model)
return model;
}
public async Task<TModel> UpdateAsync(TModel model, CancellationToken cancellationToken = default)
{
if (model.Id == 0)
{
throw new InvalidOperationException("Can't update model with ID 0");
}
await using (var conn = await _database.OpenConnectionAsync(cancellationToken))
{
await UpdateFieldsAsync(conn, null, model, _properties, cancellationToken);
}
ModelUpdated(model);
return model;
}
public void UpdateMany(IList<TModel> models)
{
if (models.Any(x => x.Id == 0))
@ -273,11 +430,31 @@ public void UpdateMany(IList<TModel> models)
}
}
public async Task UpdateManyAsync(IList<TModel> models, CancellationToken cancellationToken = default)
{
if (models.Any(x => x.Id == 0))
{
throw new InvalidOperationException("Can't update model with ID 0");
}
await using var conn = await _database.OpenConnectionAsync(cancellationToken);
await using var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted, cancellationToken);
await UpdateFieldsAsync(conn, tran, models, _properties, cancellationToken);
await tran.CommitAsync(cancellationToken);
}
protected void Delete(Expression<Func<TModel, bool>> where)
{
Delete(Builder().Where<TModel>(where));
}
protected async Task DeleteAsync(Expression<Func<TModel, bool>> where, CancellationToken cancellationToken = default)
{
await DeleteAsync(Builder().Where<TModel>(where), cancellationToken);
}
protected void Delete(SqlBuilder builder)
{
var sql = builder.AddDeleteTemplate(typeof(TModel));
@ -288,16 +465,37 @@ protected void Delete(SqlBuilder builder)
}
}
protected async Task DeleteAsync(SqlBuilder builder, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
var sql = builder.AddDeleteTemplate(typeof(TModel));
await using var conn = await _database.OpenConnectionAsync(cancellationToken);
await conn.ExecuteAsync(sql.RawSql, sql.Parameters);
}
public void Delete(TModel model)
{
Delete(model.Id);
}
public async Task DeleteAsync(TModel model, CancellationToken cancellationToken = default)
{
await DeleteAsync(model.Id, cancellationToken);
}
public void Delete(int id)
{
Delete(x => x.Id == id);
}
public async Task DeleteAsync(int id, CancellationToken cancellationToken = default)
{
await DeleteAsync(x => x.Id == id, cancellationToken);
}
public void DeleteMany(IEnumerable<int> ids)
{
if (ids.Any())
@ -306,11 +504,24 @@ public void DeleteMany(IEnumerable<int> ids)
}
}
public async Task DeleteManyAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default)
{
if (ids.Any())
{
await DeleteAsync(x => ids.Contains(x.Id), cancellationToken);
}
}
public void DeleteMany(List<TModel> models)
{
DeleteMany(models.Select(m => m.Id));
}
public async Task DeleteManyAsync(List<TModel> models, CancellationToken cancellationToken = default)
{
await DeleteManyAsync(models.Select(m => m.Id), cancellationToken);
}
public TModel Upsert(TModel model)
{
if (model.Id == 0)
@ -323,6 +534,20 @@ public TModel Upsert(TModel model)
return model;
}
public async Task<TModel> UpsertAsync(TModel model, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
if (model.Id == 0)
{
await InsertAsync(model, cancellationToken);
return model;
}
await UpdateAsync(model, cancellationToken);
return model;
}
public void Purge(bool vacuum = false)
{
using (var conn = _database.OpenConnection())
@ -336,6 +561,21 @@ public void Purge(bool vacuum = false)
}
}
public async Task PurgeAsync(bool vacuum = false, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
await using (var conn = await _database.OpenConnectionAsync(cancellationToken))
{
await conn.ExecuteAsync($"DELETE FROM \"{_table}\"");
}
if (vacuum)
{
Vacuum();
}
}
protected void Vacuum()
{
_database.Vacuum();
@ -346,6 +586,11 @@ public bool HasItems()
return Count() > 0;
}
public async Task<bool> HasItemsAsync(CancellationToken cancellationToken = default)
{
return await CountAsync(cancellationToken) > 0;
}
public void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties)
{
if (model.Id == 0)
@ -363,6 +608,23 @@ public void SetFields(TModel model, params Expression<Func<TModel, object>>[] pr
ModelUpdated(model);
}
public async Task SetFieldsAsync(TModel model, params Expression<Func<TModel, object>>[] properties)
{
if (model.Id == 0)
{
throw new InvalidOperationException("Attempted to update model without ID");
}
var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList();
await using (var conn = await _database.OpenConnectionAsync())
{
await UpdateFieldsAsync(conn, null, model, propertiesToUpdate);
}
ModelUpdated(model);
}
public void SetFields(IList<TModel> models, params Expression<Func<TModel, object>>[] properties)
{
if (models.Any(x => x.Id == 0))
@ -385,6 +647,28 @@ public void SetFields(IList<TModel> models, params Expression<Func<TModel, objec
}
}
public async Task SetFieldsAsync(IList<TModel> models, params Expression<Func<TModel, object>>[] properties)
{
if (models.Any(x => x.Id == 0))
{
throw new InvalidOperationException("Attempted to update model without ID");
}
var propertiesToUpdate = properties.Select(x => x.GetMemberName()).ToList();
await using (var conn = await _database.OpenConnectionAsync())
await using (var tran = await conn.BeginTransactionAsync(IsolationLevel.ReadCommitted))
{
await UpdateFieldsAsync(conn, tran, models, propertiesToUpdate);
await tran.CommitAsync();
}
foreach (var model in models)
{
ModelUpdated(model);
}
}
private string GetUpdateSql(List<PropertyInfo> propertiesToUpdate)
{
var sb = new StringBuilder();
@ -414,6 +698,17 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction,
RetryStrategy.Execute(static (state, _) => state.connection.Execute(state.sql, state.model, transaction: state.transaction), (connection, sql, model, transaction));
}
private async Task UpdateFieldsAsync(IDbConnection connection, IDbTransaction transaction, TModel model, List<PropertyInfo> propertiesToUpdate, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);
SqlBuilderExtensions.LogQuery(sql, model);
await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.ExecuteAsync(state.sql, state.model, transaction: state.transaction), (connection, sql, model, transaction), cancellationToken);
}
private void UpdateFields(IDbConnection connection, IDbTransaction transaction, IList<TModel> models, List<PropertyInfo> propertiesToUpdate)
{
var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);
@ -426,6 +721,20 @@ private void UpdateFields(IDbConnection connection, IDbTransaction transaction,
RetryStrategy.Execute(static (state, _) => state.connection.Execute(state.sql, state.models, transaction: state.transaction), (connection, sql, models, transaction));
}
private async Task UpdateFieldsAsync(IDbConnection connection, IDbTransaction transaction, IList<TModel> models, List<PropertyInfo> propertiesToUpdate, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
var sql = propertiesToUpdate == _properties ? _updateSql : GetUpdateSql(propertiesToUpdate);
foreach (var model in models)
{
SqlBuilderExtensions.LogQuery(sql, model);
}
await RetryStrategy.ExecuteAsync(async static (state, _) => await state.connection.ExecuteAsync(state.sql, state.models, transaction: state.transaction), (connection, sql, models, transaction), cancellationToken);
}
protected virtual SqlBuilder PagedBuilder() => Builder();
protected virtual IEnumerable<TModel> PagedQuery(SqlBuilder sql) => Query(sql);

View file

@ -1,6 +1,8 @@
using System;
using System.Data.Common;
using System.Data.SQLite;
using System.Threading;
using System.Threading.Tasks;
using Dapper;
using NLog;
using NzbDrone.Common.Instrumentation;
@ -10,23 +12,27 @@ namespace NzbDrone.Core.Datastore
public interface IDatabase
{
DbConnection OpenConnection();
Task<DbConnection> OpenConnectionAsync(CancellationToken cancellationToken = default);
Version Version { get; }
int Migration { get; }
DatabaseType DatabaseType { get; }
void Vacuum();
Task VacuumAsync(CancellationToken cancellationToken = default);
}
public class Database : IDatabase
{
private readonly string _databaseName;
private readonly Func<DbConnection> _datamapperFactory;
private readonly Func<CancellationToken, Task<DbConnection>> _datamapperFactoryAsync;
private readonly Logger _logger = NzbDroneLogger.GetLogger(typeof(Database));
public Database(string databaseName, Func<DbConnection> datamapperFactory)
public Database(string databaseName, Func<DbConnection> datamapperFactory, Func<CancellationToken, Task<DbConnection>> datamapperFactoryAsync)
{
_databaseName = databaseName;
_datamapperFactory = datamapperFactory;
_datamapperFactoryAsync = datamapperFactoryAsync;
}
public DbConnection OpenConnection()
@ -34,6 +40,11 @@ public DbConnection OpenConnection()
return _datamapperFactory();
}
public Task<DbConnection> OpenConnectionAsync(CancellationToken cancellationToken = default)
{
return _datamapperFactoryAsync(cancellationToken);
}
public DatabaseType DatabaseType
{
get
@ -83,6 +94,24 @@ public void Vacuum()
_logger.Error(e, "An Error occurred while vacuuming database.");
}
}
public async Task VacuumAsync(CancellationToken cancellationToken = default)
{
try
{
_logger.Info("Vacuuming {0} database", _databaseName);
await using (var db = await _datamapperFactoryAsync(cancellationToken))
{
await db.ExecuteAsync("Vacuum;");
}
_logger.Info("{0} database compressed", _databaseName);
}
catch (Exception e)
{
_logger.Error(e, "An Error occurred while vacuuming database.");
}
}
}
public enum DatabaseType

View file

@ -87,23 +87,42 @@ public IDatabase Create(MigrationContext migrationContext)
}
}
var db = new Database(migrationContext.MigrationType.ToString(), () =>
{
DbConnection conn;
if (connectionInfo.DatabaseType == DatabaseType.SQLite)
var db = new Database(
migrationContext.MigrationType.ToString(),
() =>
{
conn = SQLiteFactory.Instance.CreateConnection();
conn.ConnectionString = connectionInfo.ConnectionString;
}
else
{
conn = new NpgsqlConnection(connectionInfo.ConnectionString);
}
DbConnection conn;
conn.Open();
return conn;
});
if (connectionInfo.DatabaseType == DatabaseType.SQLite)
{
conn = SQLiteFactory.Instance.CreateConnection();
conn.ConnectionString = connectionInfo.ConnectionString;
}
else
{
conn = new NpgsqlConnection(connectionInfo.ConnectionString);
}
conn.Open();
return conn;
},
async cancellationToken =>
{
DbConnection conn;
if (connectionInfo.DatabaseType == DatabaseType.SQLite)
{
conn = SQLiteFactory.Instance.CreateConnection();
conn.ConnectionString = connectionInfo.ConnectionString;
}
else
{
conn = new NpgsqlConnection(connectionInfo.ConnectionString);
}
await conn.OpenAsync(cancellationToken);
return conn;
});
return db;
}

View file

@ -0,0 +1,204 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Dapper;
namespace NzbDrone.Core.Datastore.Extensions
{
public static class SqlMapperAsyncExtensions
{
public static async IAsyncEnumerable<T> QueryAsync<T>(this IDatabase db, string sql, object param = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await using var conn = await db.OpenConnectionAsync(cancellationToken);
IAsyncEnumerable<T> items;
try
{
items = conn.QueryUnbufferedAsync<T>(sql, param);
}
catch (Exception e)
{
e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param));
throw;
}
if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(T), out var lazyProperties))
{
await foreach (var item in items.WithCancellation(cancellationToken))
{
ApplyLazyLoad(db, item, lazyProperties);
yield return item;
}
}
}
public static async Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default)
{
TReturn MapWithLazy(TFirst first, TSecond second)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
return map(first, second);
}
await using var conn = await db.OpenConnectionAsync(cancellationToken);
try
{
return await conn.QueryAsync<TFirst, TSecond, TReturn>(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
catch (Exception e)
{
e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param));
throw;
}
}
public static async Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TThird, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
return map(first, second, third);
}
await using var conn = await db.OpenConnectionAsync(cancellationToken);
try
{
return await conn.QueryAsync<TFirst, TSecond, TThird, TReturn>(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
catch (Exception e)
{
e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param));
throw;
}
}
public static async Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TThird, TFourth, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TFourth, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
ApplyLazyLoad(db, fourth);
return map(first, second, third, fourth);
}
await using var conn = await db.OpenConnectionAsync(cancellationToken);
try
{
return await conn.QueryAsync<TFirst, TSecond, TThird, TFourth, TReturn>(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
catch (Exception e)
{
e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param));
throw;
}
}
public static async Task<IEnumerable<TReturn>> QueryAsync<TFirst, TSecond, TThird, TFourth, TFifth, TReturn>(this IDatabase db, string sql, Func<TFirst, TSecond, TThird, TFourth, TFifth, TReturn> map, object param = null, IDbTransaction transaction = null, bool buffered = true, string splitOn = "Id", int? commandTimeout = null, CommandType? commandType = null, CancellationToken cancellationToken = default)
{
TReturn MapWithLazy(TFirst first, TSecond second, TThird third, TFourth fourth, TFifth fifth)
{
ApplyLazyLoad(db, first);
ApplyLazyLoad(db, second);
ApplyLazyLoad(db, third);
ApplyLazyLoad(db, fourth);
ApplyLazyLoad(db, fifth);
return map(first, second, third, fourth, fifth);
}
await using var conn = await db.OpenConnectionAsync(cancellationToken);
try
{
return await conn.QueryAsync<TFirst, TSecond, TThird, TFourth, TFifth, TReturn>(sql, MapWithLazy, param, transaction, buffered, splitOn, commandTimeout, commandType);
}
catch (Exception e)
{
e.Data.Add("SQL", SqlBuilderExtensions.GetSqlLogString(sql, param));
throw;
}
}
public static IAsyncEnumerable<T> QueryAsync<T>(this IDatabase db, SqlBuilder builder, CancellationToken cancellationToken = default)
{
var type = typeof(T);
var sql = builder.Select(type).AddSelectTemplate(type);
return db.QueryAsync<T>(sql.RawSql, sql.Parameters, cancellationToken: cancellationToken);
}
public static IAsyncEnumerable<T> QueryDistinctAsync<T>(this IDatabase db, SqlBuilder builder, CancellationToken cancellationToken = default)
{
var type = typeof(T);
var sql = builder.SelectDistinct(type).AddSelectTemplate(type);
return db.QueryAsync<T>(sql.RawSql, sql.Parameters, cancellationToken: cancellationToken);
}
public static async Task<IEnumerable<T>> QueryJoinedAsync<T, T2>(this IDatabase db, SqlBuilder builder, Func<T, T2, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2)).AddSelectTemplate(type);
return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters);
}
public static async Task<IEnumerable<T>> QueryJoinedAsync<T, T2, T3>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3)).AddSelectTemplate(type);
return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters);
}
public static async Task<IEnumerable<T>> QueryJoinedAsync<T, T2, T3, T4>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T4, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4)).AddSelectTemplate(type);
return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters);
}
public static async Task<IEnumerable<T>> QueryJoinedAsync<T, T2, T3, T4, T5>(this IDatabase db, SqlBuilder builder, Func<T, T2, T3, T4, T5, T> mapper)
{
var type = typeof(T);
var sql = builder.Select(type, typeof(T2), typeof(T3), typeof(T4), typeof(T5)).AddSelectTemplate(type);
return await db.QueryAsync(sql.RawSql, mapper, sql.Parameters);
}
private static void ApplyLazyLoad<TModel>(IDatabase db, TModel model)
{
if (TableMapping.Mapper.LazyLoadList.TryGetValue(typeof(TModel), out var lazyProperties))
{
ApplyLazyLoad(db, model, lazyProperties);
}
}
private static void ApplyLazyLoad<TModel>(IDatabase db, TModel model, List<LazyLoadedProperty> lazyProperties)
{
if (model == null)
{
return;
}
foreach (var lazyProperty in lazyProperties)
{
var lazy = (ILazyLoaded)lazyProperty.LazyLoad.Clone();
lazy.Prepare(db, model);
lazyProperty.Property.SetValue(model, lazy);
}
}
}
}

View file

@ -1,5 +1,7 @@
using System;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
namespace NzbDrone.Core.Datastore
{
@ -23,6 +25,11 @@ public DbConnection OpenConnection()
return _database.OpenConnection();
}
public async Task<DbConnection> OpenConnectionAsync(CancellationToken cancellationToken = default)
{
return await _database.OpenConnectionAsync(cancellationToken);
}
public Version Version => _database.Version;
public int Migration => _database.Migration;
@ -33,5 +40,10 @@ public void Vacuum()
{
_database.Vacuum();
}
public async Task VacuumAsync(CancellationToken cancellationToken = default)
{
await _database.VacuumAsync(cancellationToken);
}
}
}

View file

@ -1,5 +1,7 @@
using System;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using StackExchange.Profiling;
using StackExchange.Profiling.Data;
@ -32,6 +34,18 @@ public DbConnection OpenConnection()
return new ProfiledDbConnection(connection, MiniProfiler.Current);
}
public async Task<DbConnection> OpenConnectionAsync(CancellationToken cancellationToken = default)
{
var connection = await _database.OpenConnectionAsync(cancellationToken);
if (_databaseType == DatabaseType.PostgreSQL)
{
return new ProfiledImplementations.NpgSqlConnection(connection, MiniProfiler.Current);
}
return new ProfiledDbConnection(connection, MiniProfiler.Current);
}
public Version Version => _database.Version;
public int Migration => _database.Migration;
@ -42,5 +56,10 @@ public void Vacuum()
{
_database.Vacuum();
}
public async Task VacuumAsync(CancellationToken cancellationToken = default)
{
await _database.VacuumAsync(cancellationToken);
}
}
}