Better synchronous IAsyncEnumerable<T> handling. (#18296)

This commit is contained in:
Pieter-Jan Briers
2023-07-26 02:03:41 +02:00
committed by GitHub
parent 12391b4881
commit 9507520e40

View File

@@ -890,32 +890,10 @@ namespace Content.Server.Database
private IAsyncEnumerable<T> RunDbCommand<T>(Func<IAsyncEnumerable<T>> command) private IAsyncEnumerable<T> RunDbCommand<T>(Func<IAsyncEnumerable<T>> command)
{ {
var enumerable = command(); var enumerable = command();
if (!_synchronous) if (_synchronous)
return enumerable; return new SyncAsyncEnumerable<T>(enumerable);
// IAsyncEnumerable<T> must be drained synchronously and returned as a fake async enumerable. return enumerable;
// If we were to let it go through like normal, it'd do a bunch of bad async stuff and break everything.
var results = new List<T>();
var enumerator = enumerable.GetAsyncEnumerator();
while (true)
{
var result = enumerator.MoveNextAsync();
if (!result.IsCompleted)
{
throw new InvalidOperationException(
"Database async enumerable is running asynchronously. " +
$"This should be impossible when the database is set to synchronous. Count: {results.Count}");
}
if (!result.Result)
break;
results.Add(enumerator.Current);
}
return new FakeAsyncEnumerable<T>(results);
} }
private DbContextOptions<PostgresServerDbContext> CreatePostgresOptions() private DbContextOptions<PostgresServerDbContext> CreatePostgresOptions()
@@ -1048,38 +1026,45 @@ namespace Content.Server.Database
public sealed record PlayTimeUpdate(NetUserId User, string Tracker, TimeSpan Time); public sealed record PlayTimeUpdate(NetUserId User, string Tracker, TimeSpan Time);
internal sealed class FakeAsyncEnumerable<T> : IAsyncEnumerable<T> internal sealed class SyncAsyncEnumerable<T> : IAsyncEnumerable<T>
{ {
private readonly IEnumerable<T> _enumerable; private readonly IAsyncEnumerable<T> _enumerable;
public FakeAsyncEnumerable(IEnumerable<T> enumerable) public SyncAsyncEnumerable(IAsyncEnumerable<T> enumerable)
{ {
_enumerable = enumerable; _enumerable = enumerable;
} }
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default) public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{ {
return new Enumerator(_enumerable.GetEnumerator()); return new Enumerator(_enumerable.GetAsyncEnumerator(cancellationToken));
} }
private sealed class Enumerator : IAsyncEnumerator<T> private sealed class Enumerator : IAsyncEnumerator<T>
{ {
private readonly IEnumerator<T> _enumerator; private readonly IAsyncEnumerator<T> _enumerator;
public Enumerator(IEnumerator<T> enumerator) public Enumerator(IAsyncEnumerator<T> enumerator)
{ {
_enumerator = enumerator; _enumerator = enumerator;
} }
public ValueTask DisposeAsync() public ValueTask DisposeAsync()
{ {
_enumerator.Dispose(); var task = _enumerator.DisposeAsync();
return ValueTask.CompletedTask; if (!task.IsCompleted)
throw new InvalidOperationException("DisposeAsync did not complete synchronously.");
return task;
} }
public ValueTask<bool> MoveNextAsync() public ValueTask<bool> MoveNextAsync()
{ {
return new ValueTask<bool>(_enumerator.MoveNext()); var task = _enumerator.MoveNextAsync();
if (!task.IsCompleted)
throw new InvalidOperationException("MoveNextAsync did not complete synchronously.");
return task;
} }
public T Current => _enumerator.Current; public T Current => _enumerator.Current;