diff --git a/src/NzbDrone.Core/Download/DownloadClientFactory.cs b/src/NzbDrone.Core/Download/DownloadClientFactory.cs index 1906ffef3..00dc11f32 100644 --- a/src/NzbDrone.Core/Download/DownloadClientFactory.cs +++ b/src/NzbDrone.Core/Download/DownloadClientFactory.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using FluentValidation.Results; using NLog; using NzbDrone.Common.Extensions; @@ -37,6 +38,11 @@ protected override List Active() return base.Active().Where(c => c.Enable).ToList(); } + protected override IAsyncEnumerable ActiveAsync(CancellationToken cancellationToken = default) + { + return base.ActiveAsync(cancellationToken).Where(c => c.Enable); + } + public override void SetProviderCharacteristics(IDownloadClient provider, DownloadClientDefinition definition) { base.SetProviderCharacteristics(provider, definition); diff --git a/src/NzbDrone.Core/Download/DownloadService.cs b/src/NzbDrone.Core/Download/DownloadService.cs index 17bed18d1..d15e520e4 100644 --- a/src/NzbDrone.Core/Download/DownloadService.cs +++ b/src/NzbDrone.Core/Download/DownloadService.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using NLog; using NzbDrone.Common.EnsureThat; @@ -118,7 +119,7 @@ public async Task DownloadReport(RemoteEpisode remoteEpisode, int? downloadClien throw new DownloadClientUnavailableException("All '{0}' download clients failed", remoteEpisode.Release.DownloadProtocol); } - private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient downloadClient) + private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient downloadClient, CancellationToken cancellationToken = default) { Ensure.That(remoteEpisode.Series, () => remoteEpisode.Series).IsNotNull(); Ensure.That(remoteEpisode.Episodes, () => remoteEpisode.Episodes).HasItems(); @@ -144,7 +145,7 @@ private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient d if (remoteEpisode.Release.IndexerId > 0) { - indexer = _indexerFactory.GetInstance(_indexerFactory.Get(remoteEpisode.Release.IndexerId)); + indexer = _indexerFactory.GetInstance(await _indexerFactory.GetAsync(remoteEpisode.Release.IndexerId, cancellationToken)); } string downloadClientId; diff --git a/src/NzbDrone.Core/ImportLists/ImportListFactory.cs b/src/NzbDrone.Core/ImportLists/ImportListFactory.cs index 0e20a82af..8e6afabad 100644 --- a/src/NzbDrone.Core/ImportLists/ImportListFactory.cs +++ b/src/NzbDrone.Core/ImportLists/ImportListFactory.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using FluentValidation.Results; using NLog; using NzbDrone.Core.Messaging.Events; @@ -35,6 +36,11 @@ protected override List Active() return base.Active().Where(c => c.Enable).ToList(); } + protected override IAsyncEnumerable ActiveAsync(CancellationToken cancellationToken = default) + { + return base.ActiveAsync(cancellationToken).Where(c => c.Enable); + } + public override void SetProviderCharacteristics(IImportList provider, ImportListDefinition definition) { base.SetProviderCharacteristics(provider, definition); diff --git a/src/NzbDrone.Core/Indexers/IndexerFactory.cs b/src/NzbDrone.Core/Indexers/IndexerFactory.cs index b22e1b669..6d894cbbb 100644 --- a/src/NzbDrone.Core/Indexers/IndexerFactory.cs +++ b/src/NzbDrone.Core/Indexers/IndexerFactory.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using FluentValidation.Results; using NLog; using NzbDrone.Common.Extensions; @@ -42,6 +43,11 @@ protected override List Active() return base.Active().Where(c => c.Enable).ToList(); } + protected override IAsyncEnumerable ActiveAsync(CancellationToken cancellationToken = default) + { + return base.ActiveAsync(cancellationToken).Where(c => c.Enable); + } + public override void SetProviderCharacteristics(IIndexer provider, IndexerDefinition definition) { base.SetProviderCharacteristics(provider, definition); diff --git a/src/NzbDrone.Core/Notifications/NotificationFactory.cs b/src/NzbDrone.Core/Notifications/NotificationFactory.cs index e4230fefe..c7dd625a9 100644 --- a/src/NzbDrone.Core/Notifications/NotificationFactory.cs +++ b/src/NzbDrone.Core/Notifications/NotificationFactory.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using FluentValidation.Results; using NLog; using NzbDrone.Core.Messaging.Events; @@ -42,6 +43,11 @@ protected override List Active() return base.Active().Where(c => c.Enable).ToList(); } + protected override IAsyncEnumerable ActiveAsync(CancellationToken cancellationToken = default) + { + return base.ActiveAsync(cancellationToken).Where(c => c.Enable); + } + public List OnGrabEnabled(bool filterBlockedNotifications = true) { if (filterBlockedNotifications) diff --git a/src/NzbDrone.Core/ThingiProvider/IProviderFactory.cs b/src/NzbDrone.Core/ThingiProvider/IProviderFactory.cs index 0cfeed9a8..c7ada2a66 100644 --- a/src/NzbDrone.Core/ThingiProvider/IProviderFactory.cs +++ b/src/NzbDrone.Core/ThingiProvider/IProviderFactory.cs @@ -1,4 +1,6 @@ using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; using FluentValidation.Results; namespace NzbDrone.Core.ThingiProvider @@ -8,16 +10,27 @@ public interface IProviderFactory where TProvider : IProvider { List All(); + IAsyncEnumerable AllAsync(CancellationToken cancellationToken = default); List GetAvailableProviders(); + IAsyncEnumerable GetAvailableProvidersAsync(CancellationToken cancellationToken = default); bool Exists(int id); + Task ExistsAsync(int id, CancellationToken cancellationToken = default); TProviderDefinition Find(int id); + Task FindAsync(int id, CancellationToken cancellationToken = default); TProviderDefinition Get(int id); + Task GetAsync(int id, CancellationToken cancellationToken = default); IEnumerable Get(IEnumerable ids); + IAsyncEnumerable GetAsync(IEnumerable ids, CancellationToken cancellationToken = default); TProviderDefinition Create(TProviderDefinition definition); + Task CreateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default); void Update(TProviderDefinition definition); + Task UpdateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default); IEnumerable Update(IEnumerable definitions); + Task> UpdateAsync(IEnumerable definitions, CancellationToken cancellationToken = default); void Delete(int id); + Task DeleteAsync(int id, CancellationToken cancellationToken = default); void Delete(IEnumerable ids); + Task DeleteAsync(IEnumerable ids, CancellationToken cancellationToken = default); IEnumerable GetDefaultDefinitions(); IEnumerable GetPresetDefinitions(TProviderDefinition providerDefinition); void SetProviderCharacteristics(TProviderDefinition definition); @@ -26,5 +39,6 @@ public interface IProviderFactory ValidationResult Test(TProviderDefinition definition); object RequestAction(TProviderDefinition definition, string action, IDictionary query); List AllForTag(int tagId); + IAsyncEnumerable AllForTagAsync(int tagId, CancellationToken cancellationToken = default); } } diff --git a/src/NzbDrone.Core/ThingiProvider/ProviderFactory.cs b/src/NzbDrone.Core/ThingiProvider/ProviderFactory.cs index 782141963..e4b18bc52 100644 --- a/src/NzbDrone.Core/ThingiProvider/ProviderFactory.cs +++ b/src/NzbDrone.Core/ThingiProvider/ProviderFactory.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using FluentValidation.Results; using Microsoft.Extensions.DependencyInjection; using NLog; @@ -39,6 +41,11 @@ public List All() return _providerRepository.All().ToList(); } + public IAsyncEnumerable AllAsync(CancellationToken cancellationToken = default) + { + return _providerRepository.AllAsync(cancellationToken); + } + public IEnumerable GetDefaultDefinitions() { foreach (var provider in _providers) @@ -91,26 +98,51 @@ public List GetAvailableProviders() return Active().Select(GetInstance).ToList(); } + public IAsyncEnumerable GetAvailableProvidersAsync(CancellationToken cancellationToken = default) + { + return ActiveAsync(cancellationToken).Select(GetInstance); + } + public bool Exists(int id) { return _providerRepository.Find(id) != null; } + public async Task ExistsAsync(int id, CancellationToken cancellationToken = default) + { + return await _providerRepository.FindAsync(id, cancellationToken) is not null; + } + public TProviderDefinition Get(int id) { return _providerRepository.Get(id); } + public async Task GetAsync(int id, CancellationToken cancellationToken = default) + { + return await _providerRepository.GetAsync(id, cancellationToken); + } + public IEnumerable Get(IEnumerable ids) { return _providerRepository.Get(ids); } + public IAsyncEnumerable GetAsync(IEnumerable ids, CancellationToken cancellationToken = default) + { + return _providerRepository.GetAsync(ids, cancellationToken); + } + public TProviderDefinition Find(int id) { return _providerRepository.Find(id); } + public async Task FindAsync(int id, CancellationToken cancellationToken = default) + { + return await _providerRepository.FindAsync(id, cancellationToken); + } + public virtual TProviderDefinition Create(TProviderDefinition definition) { var result = _providerRepository.Insert(definition); @@ -119,12 +151,26 @@ public virtual TProviderDefinition Create(TProviderDefinition definition) return result; } + public virtual async Task CreateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default) + { + var result = await _providerRepository.InsertAsync(definition, cancellationToken); + _eventAggregator.PublishEvent(new ProviderAddedEvent(result)); + + return result; + } + public virtual void Update(TProviderDefinition definition) { _providerRepository.Update(definition); _eventAggregator.PublishEvent(new ProviderUpdatedEvent(definition)); } + public virtual async Task UpdateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default) + { + await _providerRepository.UpdateAsync(definition, cancellationToken); + _eventAggregator.PublishEvent(new ProviderUpdatedEvent(definition)); + } + public virtual IEnumerable Update(IEnumerable definitions) { _providerRepository.UpdateMany(definitions.ToList()); @@ -137,12 +183,30 @@ public virtual IEnumerable Update(IEnumerable> UpdateAsync(IEnumerable definitions, CancellationToken cancellationToken = default) + { + await _providerRepository.UpdateManyAsync(definitions.ToList(), cancellationToken); + + foreach (var definition in definitions) + { + _eventAggregator.PublishEvent(new ProviderUpdatedEvent(definition)); + } + + return definitions; + } + public void Delete(int id) { _providerRepository.Delete(id); _eventAggregator.PublishEvent(new ProviderDeletedEvent(id)); } + public async Task DeleteAsync(int id, CancellationToken cancellationToken = default) + { + await _providerRepository.DeleteAsync(id, cancellationToken); + _eventAggregator.PublishEvent(new ProviderDeletedEvent(id)); + } + public void Delete(IEnumerable ids) { _providerRepository.DeleteMany(ids); @@ -153,6 +217,16 @@ public void Delete(IEnumerable ids) } } + public async Task DeleteAsync(IEnumerable ids, CancellationToken cancellationToken = default) + { + await _providerRepository.DeleteManyAsync(ids, cancellationToken); + + foreach (var id in ids) + { + _eventAggregator.PublishEvent(new ProviderDeletedEvent(id)); + } + } + public TProvider GetInstance(TProviderDefinition definition) { var type = GetImplementation(definition); @@ -185,6 +259,11 @@ protected virtual List Active() return All().Where(c => c.Settings.Validate().IsValid).ToList(); } + protected virtual IAsyncEnumerable ActiveAsync(CancellationToken cancellationToken = default) + { + return AllAsync(cancellationToken).Where(c => c.Settings.Validate().IsValid); + } + public void SetProviderCharacteristics(TProviderDefinition definition) { GetInstance(definition); @@ -212,5 +291,10 @@ public List AllForTag(int tagId) return All().Where(p => p.Tags.Contains(tagId)) .ToList(); } + + public IAsyncEnumerable AllForTagAsync(int tagId, CancellationToken cancellationToken = default) + { + return AllAsync(cancellationToken).Where(p => p.Tags.Contains(tagId)); + } } } diff --git a/src/NzbDrone.Core/ThingiProvider/ProviderRepository.cs b/src/NzbDrone.Core/ThingiProvider/ProviderRepository.cs index 44000f8d3..2600d9612 100644 --- a/src/NzbDrone.Core/ThingiProvider/ProviderRepository.cs +++ b/src/NzbDrone.Core/ThingiProvider/ProviderRepository.cs @@ -1,6 +1,8 @@ using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; +using System.Threading; using Dapper; using NzbDrone.Common.Extensions; using NzbDrone.Common.Reflection; @@ -69,5 +71,37 @@ protected override List Query(SqlBuilder builder) return results; } + + protected override async IAsyncEnumerable QueryAsync(SqlBuilder builder, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var type = typeof(TProviderDefinition); + var sql = builder.Select(type).AddSelectTemplate(type); + + await using var conn = await _database.OpenConnectionAsync(cancellationToken); + await using var reader = await conn.ExecuteReaderAsync(sql.RawSql, sql.Parameters); + + var parser = reader.GetRowParser(typeof(TProviderDefinition)); + var settingsIndex = reader.GetOrdinal(nameof(ProviderDefinition.Settings)); + + while (await reader.ReadAsync(cancellationToken)) + { + var body = await reader.IsDBNullAsync(settingsIndex, cancellationToken) ? null : reader.GetString(settingsIndex); + var item = parser(reader); + var impType = typeof(IProviderConfig).Assembly.FindTypeByName(item.ConfigContract); + + if (body.IsNullOrWhiteSpace() || impType == null) + { + item.Settings = NullConfig.Instance; + } + else + { + item.Settings = (IProviderConfig)JsonSerializer.Deserialize(body, impType, _serializerSettings); + } + + yield return item; + } + } } } diff --git a/src/Sonarr.Api.V5/Connections/ConnectionController.cs b/src/Sonarr.Api.V5/Connections/ConnectionController.cs index c5dc47d1f..e387c6dcf 100644 --- a/src/Sonarr.Api.V5/Connections/ConnectionController.cs +++ b/src/Sonarr.Api.V5/Connections/ConnectionController.cs @@ -19,13 +19,13 @@ public ConnectionController(IBroadcastSignalRMessage signalRBroadcaster, Notific } [NonAction] - public override Results>, BadRequest> UpdateProvider([FromBody] ConnectionBulkResource providerResource) + public override Task>, BadRequest>> UpdateProvider([FromBody] ConnectionBulkResource providerResource, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } [NonAction] - public override NoContent DeleteProviders([FromBody] ConnectionBulkResource resource) + public override Task DeleteProviders([FromBody] ConnectionBulkResource resource, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } diff --git a/src/Sonarr.Api.V5/Metadata/MetadataController.cs b/src/Sonarr.Api.V5/Metadata/MetadataController.cs index 1435db9bb..dd9af510b 100644 --- a/src/Sonarr.Api.V5/Metadata/MetadataController.cs +++ b/src/Sonarr.Api.V5/Metadata/MetadataController.cs @@ -19,13 +19,13 @@ public MetadataController(IBroadcastSignalRMessage signalRBroadcaster, IMetadata } [NonAction] - public override Results>, BadRequest> UpdateProvider([FromBody] MetadataBulkResource providerResource) + public override Task>, BadRequest>> UpdateProvider([FromBody] MetadataBulkResource providerResource, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } [NonAction] - public override NoContent DeleteProviders([FromBody] MetadataBulkResource resource) + public override Task DeleteProviders([FromBody] MetadataBulkResource resource, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } diff --git a/src/Sonarr.Api.V5/Provider/ProviderControllerBase.cs b/src/Sonarr.Api.V5/Provider/ProviderControllerBase.cs index ff95083f9..f1ad8c284 100644 --- a/src/Sonarr.Api.V5/Provider/ProviderControllerBase.cs +++ b/src/Sonarr.Api.V5/Provider/ProviderControllerBase.cs @@ -59,9 +59,9 @@ protected override TProviderResource GetResourceById(int id) [HttpGet] [Produces("application/json")] - public Ok> GetAll() + public async Task>> GetAll(CancellationToken cancellationToken = default) { - var providerDefinitions = _providerFactory.All(); + var providerDefinitions = await _providerFactory.AllAsync(cancellationToken).ToListAsync(cancellationToken: cancellationToken); var result = new List(providerDefinitions.Count); @@ -78,7 +78,7 @@ public Ok> GetAll() [RestPostById] [Consumes("application/json")] [Produces("application/json")] - public Results, NotFound> CreateProvider([FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None) + public async Task, NotFound>> CreateProvider([FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default) { var providerDefinition = GetDefinition(providerResource, null, skipValidation, false); @@ -87,7 +87,7 @@ public Results, NotFound> CreateProvider([FromBody] T Test(providerDefinition, skipValidation); } - providerDefinition = _providerFactory.Create(providerDefinition); + providerDefinition = await _providerFactory.CreateAsync(providerDefinition, cancellationToken); return TypedCreated(providerDefinition.Id); } @@ -95,10 +95,10 @@ public Results, NotFound> CreateProvider([FromBody] T [RestPutById] [Consumes("application/json")] [Produces("application/json")] - public Results, NotFound> UpdateProvider([FromRoute] int id, [FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None) + public async Task, NotFound>> UpdateProvider([FromRoute] int id, [FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default) { // TODO: Remove fallback to Id from body in next API version bump - var existingDefinition = _providerFactory.Find(id) ?? _providerFactory.Find(providerResource.Id); + var existingDefinition = await _providerFactory.FindAsync(id, cancellationToken) ?? await _providerFactory.FindAsync(providerResource.Id, cancellationToken); if (existingDefinition == null) { @@ -118,7 +118,7 @@ public Results, NotFound> UpdateProvider([FromRoute] if (hasDefinitionChanged) { - _providerFactory.Update(providerDefinition); + await _providerFactory.UpdateAsync(providerDefinition, cancellationToken); } return TypedAccepted(existingDefinition.Id); @@ -127,14 +127,14 @@ public Results, NotFound> UpdateProvider([FromRoute] [HttpPut("bulk")] [Consumes("application/json")] [Produces("application/json")] - public virtual Results>, BadRequest> UpdateProvider([FromBody] TBulkProviderResource providerResource) + public virtual async Task>, BadRequest>> UpdateProvider([FromBody] TBulkProviderResource providerResource, CancellationToken cancellationToken = default) { if (!providerResource.Ids.Any()) { throw new BadRequestException("ids must be provided"); } - var definitionsToUpdate = _providerFactory.Get(providerResource.Ids).ToList(); + var definitionsToUpdate = await _providerFactory.GetAsync(providerResource.Ids, cancellationToken).ToListAsync(cancellationToken); foreach (var definition in definitionsToUpdate) { @@ -162,7 +162,7 @@ public virtual Results>, BadRequest> UpdatePro _bulkResourceMapper.UpdateModel(providerResource, definitionsToUpdate); - return TypedResults.Ok(_providerFactory.Update(definitionsToUpdate).Select(x => _resourceMapper.ToResource(x))); + return TypedResults.Ok((await _providerFactory.UpdateAsync(definitionsToUpdate, cancellationToken)).Select(x => _resourceMapper.ToResource(x))); } private TProviderDefinition GetDefinition(TProviderResource providerResource, TProviderDefinition? existingDefinition, SkipValidation skipValidation, bool forceValidate) @@ -178,18 +178,18 @@ private TProviderDefinition GetDefinition(TProviderResource providerResource, TP } [RestDeleteById] - public NoContent DeleteProvider(int id) + public async Task DeleteProvider(int id, CancellationToken cancellationToken = default) { - _providerFactory.Delete(id); + await _providerFactory.DeleteAsync(id, cancellationToken); return TypedResults.NoContent(); } [HttpDelete("bulk")] [Consumes("application/json")] - public virtual NoContent DeleteProviders([FromBody] TBulkProviderResource resource) + public virtual async Task DeleteProviders([FromBody] TBulkProviderResource resource, CancellationToken cancellationToken = default) { - _providerFactory.Delete(resource.Ids); + await _providerFactory.DeleteAsync(resource.Ids, cancellationToken); return TypedResults.NoContent(); } @@ -220,9 +220,9 @@ public Ok> GetTemplates() [SkipValidation(true, false)] [HttpPost("test")] [Consumes("application/json")] - public NoContent Test([FromBody] TProviderResource providerResource, [FromQuery] SkipValidation skipValidation = SkipValidation.None) + public async Task Test([FromBody] TProviderResource providerResource, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default) { - var existingDefinition = providerResource.Id > 0 ? _providerFactory.Find(providerResource.Id) : null; + var existingDefinition = providerResource.Id > 0 ? await _providerFactory.FindAsync(providerResource.Id, cancellationToken) : null; var providerDefinition = GetDefinition(providerResource, existingDefinition, skipValidation, true); Test(providerDefinition, skipValidation); @@ -232,14 +232,13 @@ public NoContent Test([FromBody] TProviderResource providerResource, [FromQuery] [HttpPost("testall")] [Produces("application/json")] - public Results>, BadRequest>> TestAll() + public async Task>, BadRequest>>> TestAll(CancellationToken cancellationToken = default) { - var providerDefinitions = _providerFactory.All() - .Where(c => c.Settings.Validate().IsValid && c.Enable) - .ToList(); var result = new List(); - foreach (var definition in providerDefinitions) + await foreach (var definition in _providerFactory.AllAsync(cancellationToken) + .Where(c => c.Settings.Validate().IsValid && c.Enable) + .WithCancellation(cancellationToken)) { var validationFailures = new List(); @@ -260,9 +259,9 @@ public Results>, BadRequest RequestAction([FromRoute] string name, [FromBody] TProviderResource providerResource) + public async Task> RequestAction([FromRoute] string name, [FromBody] TProviderResource providerResource, CancellationToken cancellationToken = default) { - var existingDefinition = providerResource.Id > 0 ? _providerFactory.Find(providerResource.Id) : null; + var existingDefinition = providerResource.Id > 0 ? await _providerFactory.FindAsync(providerResource.Id, cancellationToken) : null; var providerDefinition = GetDefinition(providerResource, existingDefinition, SkipValidation.All, false); var query = Request.Query.ToDictionary(x => x.Key, x => x.Value.ToString());