Add async overloaded methods for Providers

This commit is contained in:
Bogdan 2026-04-27 17:56:45 +03:00
parent b1ea411400
commit ef76ee3897
11 changed files with 185 additions and 29 deletions

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using FluentValidation.Results;
using NLog;
using NzbDrone.Common.Extensions;
@ -37,6 +38,11 @@ protected override List<DownloadClientDefinition> Active()
return base.Active().Where(c => c.Enable).ToList();
}
protected override IAsyncEnumerable<DownloadClientDefinition> ActiveAsync(CancellationToken cancellationToken = default)
{
return base.ActiveAsync(cancellationToken).Where(c => c.Enable);
}
public override void SetProviderCharacteristics(IDownloadClient provider, DownloadClientDefinition definition)
{
base.SetProviderCharacteristics(provider, definition);

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NLog;
using NzbDrone.Common.EnsureThat;
@ -118,7 +119,7 @@ public async Task DownloadReport(RemoteEpisode remoteEpisode, int? downloadClien
throw new DownloadClientUnavailableException("All '{0}' download clients failed", remoteEpisode.Release.DownloadProtocol);
}
private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient downloadClient)
private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient downloadClient, CancellationToken cancellationToken = default)
{
Ensure.That(remoteEpisode.Series, () => remoteEpisode.Series).IsNotNull();
Ensure.That(remoteEpisode.Episodes, () => remoteEpisode.Episodes).HasItems();
@ -144,7 +145,7 @@ private async Task DownloadReport(RemoteEpisode remoteEpisode, IDownloadClient d
if (remoteEpisode.Release.IndexerId > 0)
{
indexer = _indexerFactory.GetInstance(_indexerFactory.Get(remoteEpisode.Release.IndexerId));
indexer = _indexerFactory.GetInstance(await _indexerFactory.GetAsync(remoteEpisode.Release.IndexerId, cancellationToken));
}
string downloadClientId;

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using FluentValidation.Results;
using NLog;
using NzbDrone.Core.Messaging.Events;
@ -35,6 +36,11 @@ protected override List<ImportListDefinition> Active()
return base.Active().Where(c => c.Enable).ToList();
}
protected override IAsyncEnumerable<ImportListDefinition> ActiveAsync(CancellationToken cancellationToken = default)
{
return base.ActiveAsync(cancellationToken).Where(c => c.Enable);
}
public override void SetProviderCharacteristics(IImportList provider, ImportListDefinition definition)
{
base.SetProviderCharacteristics(provider, definition);

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using FluentValidation.Results;
using NLog;
using NzbDrone.Common.Extensions;
@ -42,6 +43,11 @@ protected override List<IndexerDefinition> Active()
return base.Active().Where(c => c.Enable).ToList();
}
protected override IAsyncEnumerable<IndexerDefinition> ActiveAsync(CancellationToken cancellationToken = default)
{
return base.ActiveAsync(cancellationToken).Where(c => c.Enable);
}
public override void SetProviderCharacteristics(IIndexer provider, IndexerDefinition definition)
{
base.SetProviderCharacteristics(provider, definition);

View file

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using FluentValidation.Results;
using NLog;
using NzbDrone.Core.Messaging.Events;
@ -42,6 +43,11 @@ protected override List<NotificationDefinition> Active()
return base.Active().Where(c => c.Enable).ToList();
}
protected override IAsyncEnumerable<NotificationDefinition> ActiveAsync(CancellationToken cancellationToken = default)
{
return base.ActiveAsync(cancellationToken).Where(c => c.Enable);
}
public List<INotification> OnGrabEnabled(bool filterBlockedNotifications = true)
{
if (filterBlockedNotifications)

View file

@ -1,4 +1,6 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using FluentValidation.Results;
namespace NzbDrone.Core.ThingiProvider
@ -8,16 +10,27 @@ public interface IProviderFactory<TProvider, TProviderDefinition>
where TProvider : IProvider
{
List<TProviderDefinition> All();
IAsyncEnumerable<TProviderDefinition> AllAsync(CancellationToken cancellationToken = default);
List<TProvider> GetAvailableProviders();
IAsyncEnumerable<TProvider> GetAvailableProvidersAsync(CancellationToken cancellationToken = default);
bool Exists(int id);
Task<bool> ExistsAsync(int id, CancellationToken cancellationToken = default);
TProviderDefinition Find(int id);
Task<TProviderDefinition> FindAsync(int id, CancellationToken cancellationToken = default);
TProviderDefinition Get(int id);
Task<TProviderDefinition> GetAsync(int id, CancellationToken cancellationToken = default);
IEnumerable<TProviderDefinition> Get(IEnumerable<int> ids);
IAsyncEnumerable<TProviderDefinition> GetAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default);
TProviderDefinition Create(TProviderDefinition definition);
Task<TProviderDefinition> CreateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default);
void Update(TProviderDefinition definition);
Task UpdateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default);
IEnumerable<TProviderDefinition> Update(IEnumerable<TProviderDefinition> definitions);
Task<IEnumerable<TProviderDefinition>> UpdateAsync(IEnumerable<TProviderDefinition> definitions, CancellationToken cancellationToken = default);
void Delete(int id);
Task DeleteAsync(int id, CancellationToken cancellationToken = default);
void Delete(IEnumerable<int> ids);
Task DeleteAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default);
IEnumerable<TProviderDefinition> GetDefaultDefinitions();
IEnumerable<TProviderDefinition> GetPresetDefinitions(TProviderDefinition providerDefinition);
void SetProviderCharacteristics(TProviderDefinition definition);
@ -26,5 +39,6 @@ public interface IProviderFactory<TProvider, TProviderDefinition>
ValidationResult Test(TProviderDefinition definition);
object RequestAction(TProviderDefinition definition, string action, IDictionary<string, string> query);
List<TProviderDefinition> AllForTag(int tagId);
IAsyncEnumerable<TProviderDefinition> AllForTagAsync(int tagId, CancellationToken cancellationToken = default);
}
}

View file

@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using FluentValidation.Results;
using Microsoft.Extensions.DependencyInjection;
using NLog;
@ -39,6 +41,11 @@ public List<TProviderDefinition> All()
return _providerRepository.All().ToList();
}
public IAsyncEnumerable<TProviderDefinition> AllAsync(CancellationToken cancellationToken = default)
{
return _providerRepository.AllAsync(cancellationToken);
}
public IEnumerable<TProviderDefinition> GetDefaultDefinitions()
{
foreach (var provider in _providers)
@ -91,26 +98,51 @@ public List<TProvider> GetAvailableProviders()
return Active().Select(GetInstance).ToList();
}
public IAsyncEnumerable<TProvider> GetAvailableProvidersAsync(CancellationToken cancellationToken = default)
{
return ActiveAsync(cancellationToken).Select(GetInstance);
}
public bool Exists(int id)
{
return _providerRepository.Find(id) != null;
}
public async Task<bool> ExistsAsync(int id, CancellationToken cancellationToken = default)
{
return await _providerRepository.FindAsync(id, cancellationToken) is not null;
}
public TProviderDefinition Get(int id)
{
return _providerRepository.Get(id);
}
public async Task<TProviderDefinition> GetAsync(int id, CancellationToken cancellationToken = default)
{
return await _providerRepository.GetAsync(id, cancellationToken);
}
public IEnumerable<TProviderDefinition> Get(IEnumerable<int> ids)
{
return _providerRepository.Get(ids);
}
public IAsyncEnumerable<TProviderDefinition> GetAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default)
{
return _providerRepository.GetAsync(ids, cancellationToken);
}
public TProviderDefinition Find(int id)
{
return _providerRepository.Find(id);
}
public async Task<TProviderDefinition> FindAsync(int id, CancellationToken cancellationToken = default)
{
return await _providerRepository.FindAsync(id, cancellationToken);
}
public virtual TProviderDefinition Create(TProviderDefinition definition)
{
var result = _providerRepository.Insert(definition);
@ -119,12 +151,26 @@ public virtual TProviderDefinition Create(TProviderDefinition definition)
return result;
}
public virtual async Task<TProviderDefinition> CreateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default)
{
var result = await _providerRepository.InsertAsync(definition, cancellationToken);
_eventAggregator.PublishEvent(new ProviderAddedEvent<TProvider>(result));
return result;
}
public virtual void Update(TProviderDefinition definition)
{
_providerRepository.Update(definition);
_eventAggregator.PublishEvent(new ProviderUpdatedEvent<TProvider>(definition));
}
public virtual async Task UpdateAsync(TProviderDefinition definition, CancellationToken cancellationToken = default)
{
await _providerRepository.UpdateAsync(definition, cancellationToken);
_eventAggregator.PublishEvent(new ProviderUpdatedEvent<TProvider>(definition));
}
public virtual IEnumerable<TProviderDefinition> Update(IEnumerable<TProviderDefinition> definitions)
{
_providerRepository.UpdateMany(definitions.ToList());
@ -137,12 +183,30 @@ public virtual IEnumerable<TProviderDefinition> Update(IEnumerable<TProviderDefi
return definitions;
}
public virtual async Task<IEnumerable<TProviderDefinition>> UpdateAsync(IEnumerable<TProviderDefinition> definitions, CancellationToken cancellationToken = default)
{
await _providerRepository.UpdateManyAsync(definitions.ToList(), cancellationToken);
foreach (var definition in definitions)
{
_eventAggregator.PublishEvent(new ProviderUpdatedEvent<TProvider>(definition));
}
return definitions;
}
public void Delete(int id)
{
_providerRepository.Delete(id);
_eventAggregator.PublishEvent(new ProviderDeletedEvent<TProvider>(id));
}
public async Task DeleteAsync(int id, CancellationToken cancellationToken = default)
{
await _providerRepository.DeleteAsync(id, cancellationToken);
_eventAggregator.PublishEvent(new ProviderDeletedEvent<TProvider>(id));
}
public void Delete(IEnumerable<int> ids)
{
_providerRepository.DeleteMany(ids);
@ -153,6 +217,16 @@ public void Delete(IEnumerable<int> ids)
}
}
public async Task DeleteAsync(IEnumerable<int> ids, CancellationToken cancellationToken = default)
{
await _providerRepository.DeleteManyAsync(ids, cancellationToken);
foreach (var id in ids)
{
_eventAggregator.PublishEvent(new ProviderDeletedEvent<TProvider>(id));
}
}
public TProvider GetInstance(TProviderDefinition definition)
{
var type = GetImplementation(definition);
@ -185,6 +259,11 @@ protected virtual List<TProviderDefinition> Active()
return All().Where(c => c.Settings.Validate().IsValid).ToList();
}
protected virtual IAsyncEnumerable<TProviderDefinition> ActiveAsync(CancellationToken cancellationToken = default)
{
return AllAsync(cancellationToken).Where(c => c.Settings.Validate().IsValid);
}
public void SetProviderCharacteristics(TProviderDefinition definition)
{
GetInstance(definition);
@ -212,5 +291,10 @@ public List<TProviderDefinition> AllForTag(int tagId)
return All().Where(p => p.Tags.Contains(tagId))
.ToList();
}
public IAsyncEnumerable<TProviderDefinition> AllForTagAsync(int tagId, CancellationToken cancellationToken = default)
{
return AllAsync(cancellationToken).Where(p => p.Tags.Contains(tagId));
}
}
}

View file

@ -1,6 +1,8 @@
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using Dapper;
using NzbDrone.Common.Extensions;
using NzbDrone.Common.Reflection;
@ -69,5 +71,37 @@ protected override List<TProviderDefinition> Query(SqlBuilder builder)
return results;
}
protected override async IAsyncEnumerable<TProviderDefinition> QueryAsync(SqlBuilder builder, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
var type = typeof(TProviderDefinition);
var sql = builder.Select(type).AddSelectTemplate(type);
await using var conn = await _database.OpenConnectionAsync(cancellationToken);
await using var reader = await conn.ExecuteReaderAsync(sql.RawSql, sql.Parameters);
var parser = reader.GetRowParser<TProviderDefinition>(typeof(TProviderDefinition));
var settingsIndex = reader.GetOrdinal(nameof(ProviderDefinition.Settings));
while (await reader.ReadAsync(cancellationToken))
{
var body = await reader.IsDBNullAsync(settingsIndex, cancellationToken) ? null : reader.GetString(settingsIndex);
var item = parser(reader);
var impType = typeof(IProviderConfig).Assembly.FindTypeByName(item.ConfigContract);
if (body.IsNullOrWhiteSpace() || impType == null)
{
item.Settings = NullConfig.Instance;
}
else
{
item.Settings = (IProviderConfig)JsonSerializer.Deserialize(body, impType, _serializerSettings);
}
yield return item;
}
}
}
}

View file

@ -19,13 +19,13 @@ public ConnectionController(IBroadcastSignalRMessage signalRBroadcaster, Notific
}
[NonAction]
public override Results<Ok<IEnumerable<ConnectionResource>>, BadRequest> UpdateProvider([FromBody] ConnectionBulkResource providerResource)
public override Task<Results<Ok<IEnumerable<ConnectionResource>>, BadRequest>> UpdateProvider([FromBody] ConnectionBulkResource providerResource, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
[NonAction]
public override NoContent DeleteProviders([FromBody] ConnectionBulkResource resource)
public override Task<NoContent> DeleteProviders([FromBody] ConnectionBulkResource resource, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

View file

@ -19,13 +19,13 @@ public MetadataController(IBroadcastSignalRMessage signalRBroadcaster, IMetadata
}
[NonAction]
public override Results<Ok<IEnumerable<MetadataResource>>, BadRequest> UpdateProvider([FromBody] MetadataBulkResource providerResource)
public override Task<Results<Ok<IEnumerable<MetadataResource>>, BadRequest>> UpdateProvider([FromBody] MetadataBulkResource providerResource, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
[NonAction]
public override NoContent DeleteProviders([FromBody] MetadataBulkResource resource)
public override Task<NoContent> DeleteProviders([FromBody] MetadataBulkResource resource, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

View file

@ -59,9 +59,9 @@ protected override TProviderResource GetResourceById(int id)
[HttpGet]
[Produces("application/json")]
public Ok<List<TProviderResource>> GetAll()
public async Task<Ok<List<TProviderResource>>> GetAll(CancellationToken cancellationToken = default)
{
var providerDefinitions = _providerFactory.All();
var providerDefinitions = await _providerFactory.AllAsync(cancellationToken).ToListAsync(cancellationToken: cancellationToken);
var result = new List<TProviderResource>(providerDefinitions.Count);
@ -78,7 +78,7 @@ public Ok<List<TProviderResource>> GetAll()
[RestPostById]
[Consumes("application/json")]
[Produces("application/json")]
public Results<Created<TProviderResource>, NotFound> CreateProvider([FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None)
public async Task<Results<Created<TProviderResource>, NotFound>> CreateProvider([FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default)
{
var providerDefinition = GetDefinition(providerResource, null, skipValidation, false);
@ -87,7 +87,7 @@ public Results<Created<TProviderResource>, NotFound> CreateProvider([FromBody] T
Test(providerDefinition, skipValidation);
}
providerDefinition = _providerFactory.Create(providerDefinition);
providerDefinition = await _providerFactory.CreateAsync(providerDefinition, cancellationToken);
return TypedCreated(providerDefinition.Id);
}
@ -95,10 +95,10 @@ public Results<Created<TProviderResource>, NotFound> CreateProvider([FromBody] T
[RestPutById]
[Consumes("application/json")]
[Produces("application/json")]
public Results<Accepted<TProviderResource>, NotFound> UpdateProvider([FromRoute] int id, [FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None)
public async Task<Results<Accepted<TProviderResource>, NotFound>> UpdateProvider([FromRoute] int id, [FromBody] TProviderResource providerResource, [FromQuery] bool skipTesting = false, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default)
{
// TODO: Remove fallback to Id from body in next API version bump
var existingDefinition = _providerFactory.Find(id) ?? _providerFactory.Find(providerResource.Id);
var existingDefinition = await _providerFactory.FindAsync(id, cancellationToken) ?? await _providerFactory.FindAsync(providerResource.Id, cancellationToken);
if (existingDefinition == null)
{
@ -118,7 +118,7 @@ public Results<Accepted<TProviderResource>, NotFound> UpdateProvider([FromRoute]
if (hasDefinitionChanged)
{
_providerFactory.Update(providerDefinition);
await _providerFactory.UpdateAsync(providerDefinition, cancellationToken);
}
return TypedAccepted(existingDefinition.Id);
@ -127,14 +127,14 @@ public Results<Accepted<TProviderResource>, NotFound> UpdateProvider([FromRoute]
[HttpPut("bulk")]
[Consumes("application/json")]
[Produces("application/json")]
public virtual Results<Ok<IEnumerable<TProviderResource>>, BadRequest> UpdateProvider([FromBody] TBulkProviderResource providerResource)
public virtual async Task<Results<Ok<IEnumerable<TProviderResource>>, BadRequest>> UpdateProvider([FromBody] TBulkProviderResource providerResource, CancellationToken cancellationToken = default)
{
if (!providerResource.Ids.Any())
{
throw new BadRequestException("ids must be provided");
}
var definitionsToUpdate = _providerFactory.Get(providerResource.Ids).ToList();
var definitionsToUpdate = await _providerFactory.GetAsync(providerResource.Ids, cancellationToken).ToListAsync(cancellationToken);
foreach (var definition in definitionsToUpdate)
{
@ -162,7 +162,7 @@ public virtual Results<Ok<IEnumerable<TProviderResource>>, BadRequest> UpdatePro
_bulkResourceMapper.UpdateModel(providerResource, definitionsToUpdate);
return TypedResults.Ok(_providerFactory.Update(definitionsToUpdate).Select(x => _resourceMapper.ToResource(x)));
return TypedResults.Ok((await _providerFactory.UpdateAsync(definitionsToUpdate, cancellationToken)).Select(x => _resourceMapper.ToResource(x)));
}
private TProviderDefinition GetDefinition(TProviderResource providerResource, TProviderDefinition? existingDefinition, SkipValidation skipValidation, bool forceValidate)
@ -178,18 +178,18 @@ private TProviderDefinition GetDefinition(TProviderResource providerResource, TP
}
[RestDeleteById]
public NoContent DeleteProvider(int id)
public async Task<NoContent> DeleteProvider(int id, CancellationToken cancellationToken = default)
{
_providerFactory.Delete(id);
await _providerFactory.DeleteAsync(id, cancellationToken);
return TypedResults.NoContent();
}
[HttpDelete("bulk")]
[Consumes("application/json")]
public virtual NoContent DeleteProviders([FromBody] TBulkProviderResource resource)
public virtual async Task<NoContent> DeleteProviders([FromBody] TBulkProviderResource resource, CancellationToken cancellationToken = default)
{
_providerFactory.Delete(resource.Ids);
await _providerFactory.DeleteAsync(resource.Ids, cancellationToken);
return TypedResults.NoContent();
}
@ -220,9 +220,9 @@ public Ok<List<TProviderResource>> GetTemplates()
[SkipValidation(true, false)]
[HttpPost("test")]
[Consumes("application/json")]
public NoContent Test([FromBody] TProviderResource providerResource, [FromQuery] SkipValidation skipValidation = SkipValidation.None)
public async Task<NoContent> Test([FromBody] TProviderResource providerResource, [FromQuery] SkipValidation skipValidation = SkipValidation.None, CancellationToken cancellationToken = default)
{
var existingDefinition = providerResource.Id > 0 ? _providerFactory.Find(providerResource.Id) : null;
var existingDefinition = providerResource.Id > 0 ? await _providerFactory.FindAsync(providerResource.Id, cancellationToken) : null;
var providerDefinition = GetDefinition(providerResource, existingDefinition, skipValidation, true);
Test(providerDefinition, skipValidation);
@ -232,14 +232,13 @@ public NoContent Test([FromBody] TProviderResource providerResource, [FromQuery]
[HttpPost("testall")]
[Produces("application/json")]
public Results<Ok<List<ProviderTestAllResult>>, BadRequest<List<ProviderTestAllResult>>> TestAll()
public async Task<Results<Ok<List<ProviderTestAllResult>>, BadRequest<List<ProviderTestAllResult>>>> TestAll(CancellationToken cancellationToken = default)
{
var providerDefinitions = _providerFactory.All()
.Where(c => c.Settings.Validate().IsValid && c.Enable)
.ToList();
var result = new List<ProviderTestAllResult>();
foreach (var definition in providerDefinitions)
await foreach (var definition in _providerFactory.AllAsync(cancellationToken)
.Where(c => c.Settings.Validate().IsValid && c.Enable)
.WithCancellation(cancellationToken))
{
var validationFailures = new List<ValidationFailure>();
@ -260,9 +259,9 @@ public Results<Ok<List<ProviderTestAllResult>>, BadRequest<List<ProviderTestAllR
[HttpPost("action/{name}")]
[Consumes("application/json")]
[Produces("application/json")]
public Results<ContentHttpResult, BadRequest> RequestAction([FromRoute] string name, [FromBody] TProviderResource providerResource)
public async Task<Results<ContentHttpResult, BadRequest>> RequestAction([FromRoute] string name, [FromBody] TProviderResource providerResource, CancellationToken cancellationToken = default)
{
var existingDefinition = providerResource.Id > 0 ? _providerFactory.Find(providerResource.Id) : null;
var existingDefinition = providerResource.Id > 0 ? await _providerFactory.FindAsync(providerResource.Id, cancellationToken) : null;
var providerDefinition = GetDefinition(providerResource, existingDefinition, SkipValidation.All, false);
var query = Request.Query.ToDictionary(x => x.Key, x => x.Value.ToString());