(Probably) fix random integration test failures. (#18270)

This commit is contained in:
Pieter-Jan Briers
2023-07-25 03:10:50 +02:00
committed by GitHub
parent 5fa1849948
commit 978887bf03
5 changed files with 164 additions and 36 deletions

View File

@@ -14,6 +14,7 @@ using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Npgsql;
using Prometheus;
using Robust.Shared.Asynchronous;
using Robust.Shared.Configuration;
using Robust.Shared.ContentPack;
using Robust.Shared.Network;
@@ -291,6 +292,7 @@ namespace Content.Server.Database
[Dependency] private readonly IConfigurationManager _cfg = default!;
[Dependency] private readonly IResourceManager _res = default!;
[Dependency] private readonly ILogManager _logMgr = default!;
[Dependency] private readonly ITaskManager _taskManager = default!;
private ServerDbBase _db = default!;
private LoggingProvider _msLogProvider = default!;
@@ -316,7 +318,7 @@ namespace Content.Server.Database
{
case "sqlite":
SetupSqlite(out var contextFunc, out var inMemory);
_db = new ServerDbSqlite(contextFunc, inMemory, _cfg);
_db = new ServerDbSqlite(contextFunc, inMemory, _cfg, _synchronous);
break;
case "postgres":
var pgOptions = CreatePostgresOptions();
@@ -619,25 +621,25 @@ namespace Content.Server.Database
public IAsyncEnumerable<string> GetAdminLogMessages(LogFilter? filter = null)
{
DbReadOpsMetric.Inc();
return _db.GetAdminLogMessages(filter);
return RunDbCommand(() => _db.GetAdminLogMessages(filter));
}
public IAsyncEnumerable<SharedAdminLog> GetAdminLogs(LogFilter? filter = null)
{
DbReadOpsMetric.Inc();
return _db.GetAdminLogs(filter);
return RunDbCommand(() => _db.GetAdminLogs(filter));
}
public IAsyncEnumerable<JsonDocument> GetAdminLogsJson(LogFilter? filter = null)
{
DbReadOpsMetric.Inc();
return _db.GetAdminLogsJson(filter);
return RunDbCommand(() => _db.GetAdminLogsJson(filter));
}
public Task<int> CountAdminLogs(int round)
{
DbReadOpsMetric.Inc();
return _db.CountAdminLogs(round);
return RunDbCommand(() => _db.CountAdminLogs(round));
}
public Task<bool> GetWhitelistStatusAsync(NetUserId player)
@@ -857,7 +859,7 @@ namespace Content.Server.Database
private Task<T> RunDbCommand<T>(Func<Task<T>> command)
{
if (_synchronous)
return command();
return RunDbCommandCoreSync(command);
return Task.Run(command);
}
@@ -865,11 +867,57 @@ namespace Content.Server.Database
private Task RunDbCommand(Func<Task> command)
{
if (_synchronous)
return command();
return RunDbCommandCoreSync(command);
return Task.Run(command);
}
private static T RunDbCommandCoreSync<T>(Func<T> command) where T : IAsyncResult
{
var task = command();
if (!task.IsCompleted)
{
// We can't just do BlockWaitOnTask here, because that could cause deadlocks.
// This flag is only intended for integration tests. If we trip this, it's a bug.
throw new InvalidOperationException(
"Database task is running asynchronously. " +
"This should be impossible when the database is set to synchronous.");
}
return task;
}
private IAsyncEnumerable<T> RunDbCommand<T>(Func<IAsyncEnumerable<T>> command)
{
var enumerable = command();
if (!_synchronous)
return enumerable;
// IAsyncEnumerable<T> must be drained synchronously and returned as a fake async enumerable.
// If we were to let it go through like normal, it'd do a bunch of bad async stuff and break everything.
var results = new List<T>();
var enumerator = enumerable.GetAsyncEnumerator();
while (true)
{
var result = enumerator.MoveNextAsync();
if (!result.IsCompleted)
{
throw new InvalidOperationException(
"Database async enumerable is running asynchronously. " +
$"This should be impossible when the database is set to synchronous. Count: {results.Count}");
}
if (!result.Result)
break;
results.Add(enumerator.Current);
}
return new FakeAsyncEnumerable<T>(results);
}
private DbContextOptions<PostgresServerDbContext> CreatePostgresOptions()
{
var host = _cfg.GetCVar(CCVars.DatabasePgHost);
@@ -999,4 +1047,42 @@ namespace Content.Server.Database
}
public sealed record PlayTimeUpdate(NetUserId User, string Tracker, TimeSpan Time);
internal sealed class FakeAsyncEnumerable<T> : IAsyncEnumerable<T>
{
private readonly IEnumerable<T> _enumerable;
public FakeAsyncEnumerable(IEnumerable<T> enumerable)
{
_enumerable = enumerable;
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new Enumerator(_enumerable.GetEnumerator());
}
private sealed class Enumerator : IAsyncEnumerator<T>
{
private readonly IEnumerator<T> _enumerator;
public Enumerator(IEnumerator<T> enumerator)
{
_enumerator = enumerator;
}
public ValueTask DisposeAsync()
{
_enumerator.Dispose();
return ValueTask.CompletedTask;
}
public ValueTask<bool> MoveNextAsync()
{
return new ValueTask<bool>(_enumerator.MoveNext());
}
public T Current => _enumerator.Current;
}
}
}