diff --git a/Content.Server/Administration/Logs/AdminLogManager.cs b/Content.Server/Administration/Logs/AdminLogManager.cs index c2ecea57a9..f7d2346568 100644 --- a/Content.Server/Administration/Logs/AdminLogManager.cs +++ b/Content.Server/Administration/Logs/AdminLogManager.cs @@ -12,6 +12,7 @@ using Robust.Shared; using Robust.Shared.Configuration; using Robust.Shared.Reflection; using Robust.Shared.Timing; +using Robust.Shared.Utility; namespace Content.Server.Administration.Logs; @@ -196,10 +197,7 @@ public sealed partial class AdminLogManager : SharedAdminLogManager, IAdminLogMa PreRoundQueue.Set(0); // ship the logs to Azkaban - var task = Task.Run(async () => - { - await _db.AddAdminLogs(copy); - }); + var task = _db.AddAdminLogs(copy); _sawmill.Debug($"Saving {copy.Count} admin logs."); diff --git a/Content.Server/Database/ServerDbManager.cs b/Content.Server/Database/ServerDbManager.cs index f5c116ca0a..42d0bc153b 100644 --- a/Content.Server/Database/ServerDbManager.cs +++ b/Content.Server/Database/ServerDbManager.cs @@ -14,6 +14,7 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; using Npgsql; using Prometheus; +using Robust.Shared.Asynchronous; using Robust.Shared.Configuration; using Robust.Shared.ContentPack; using Robust.Shared.Network; @@ -291,6 +292,7 @@ namespace Content.Server.Database [Dependency] private readonly IConfigurationManager _cfg = default!; [Dependency] private readonly IResourceManager _res = default!; [Dependency] private readonly ILogManager _logMgr = default!; + [Dependency] private readonly ITaskManager _taskManager = default!; private ServerDbBase _db = default!; private LoggingProvider _msLogProvider = default!; @@ -316,7 +318,7 @@ namespace Content.Server.Database { case "sqlite": SetupSqlite(out var contextFunc, out var inMemory); - _db = new ServerDbSqlite(contextFunc, inMemory, _cfg); + _db = new ServerDbSqlite(contextFunc, inMemory, _cfg, _synchronous); break; case "postgres": var pgOptions = CreatePostgresOptions(); @@ -619,25 +621,25 @@ namespace Content.Server.Database public IAsyncEnumerable GetAdminLogMessages(LogFilter? filter = null) { DbReadOpsMetric.Inc(); - return _db.GetAdminLogMessages(filter); + return RunDbCommand(() => _db.GetAdminLogMessages(filter)); } public IAsyncEnumerable GetAdminLogs(LogFilter? filter = null) { DbReadOpsMetric.Inc(); - return _db.GetAdminLogs(filter); + return RunDbCommand(() => _db.GetAdminLogs(filter)); } public IAsyncEnumerable GetAdminLogsJson(LogFilter? filter = null) { DbReadOpsMetric.Inc(); - return _db.GetAdminLogsJson(filter); + return RunDbCommand(() => _db.GetAdminLogsJson(filter)); } public Task CountAdminLogs(int round) { DbReadOpsMetric.Inc(); - return _db.CountAdminLogs(round); + return RunDbCommand(() => _db.CountAdminLogs(round)); } public Task GetWhitelistStatusAsync(NetUserId player) @@ -857,7 +859,7 @@ namespace Content.Server.Database private Task RunDbCommand(Func> command) { if (_synchronous) - return command(); + return RunDbCommandCoreSync(command); return Task.Run(command); } @@ -865,11 +867,57 @@ namespace Content.Server.Database private Task RunDbCommand(Func command) { if (_synchronous) - return command(); + return RunDbCommandCoreSync(command); return Task.Run(command); } + private static T RunDbCommandCoreSync(Func command) where T : IAsyncResult + { + var task = command(); + if (!task.IsCompleted) + { + // We can't just do BlockWaitOnTask here, because that could cause deadlocks. + // This flag is only intended for integration tests. If we trip this, it's a bug. + throw new InvalidOperationException( + "Database task is running asynchronously. " + + "This should be impossible when the database is set to synchronous."); + } + + return task; + } + + private IAsyncEnumerable RunDbCommand(Func> command) + { + var enumerable = command(); + if (!_synchronous) + return enumerable; + + // IAsyncEnumerable must be drained synchronously and returned as a fake async enumerable. + // If we were to let it go through like normal, it'd do a bunch of bad async stuff and break everything. + + var results = new List(); + var enumerator = enumerable.GetAsyncEnumerator(); + + while (true) + { + var result = enumerator.MoveNextAsync(); + if (!result.IsCompleted) + { + throw new InvalidOperationException( + "Database async enumerable is running asynchronously. " + + $"This should be impossible when the database is set to synchronous. Count: {results.Count}"); + } + + if (!result.Result) + break; + + results.Add(enumerator.Current); + } + + return new FakeAsyncEnumerable(results); + } + private DbContextOptions CreatePostgresOptions() { var host = _cfg.GetCVar(CCVars.DatabasePgHost); @@ -999,4 +1047,42 @@ namespace Content.Server.Database } public sealed record PlayTimeUpdate(NetUserId User, string Tracker, TimeSpan Time); + + internal sealed class FakeAsyncEnumerable : IAsyncEnumerable + { + private readonly IEnumerable _enumerable; + + public FakeAsyncEnumerable(IEnumerable enumerable) + { + _enumerable = enumerable; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(_enumerable.GetEnumerator()); + } + + private sealed class Enumerator : IAsyncEnumerator + { + private readonly IEnumerator _enumerator; + + public Enumerator(IEnumerator enumerator) + { + _enumerator = enumerator; + } + + public ValueTask DisposeAsync() + { + _enumerator.Dispose(); + return ValueTask.CompletedTask; + } + + public ValueTask MoveNextAsync() + { + return new ValueTask(_enumerator.MoveNext()); + } + + public T Current => _enumerator.Current; + } + } } diff --git a/Content.Server/Database/ServerDbSqlite.cs b/Content.Server/Database/ServerDbSqlite.cs index 32ae4dc681..781c16ddb9 100644 --- a/Content.Server/Database/ServerDbSqlite.cs +++ b/Content.Server/Database/ServerDbSqlite.cs @@ -20,16 +20,15 @@ namespace Content.Server.Database { private readonly Func> _options; - private readonly SemaphoreSlim _prefsSemaphore; + private readonly ConcurrencySemaphore _prefsSemaphore; private readonly Task _dbReadyTask; private int _msDelay; - public ServerDbSqlite( - Func> options, + public ServerDbSqlite(Func> options, bool inMemory, - IConfigurationManager cfg) + IConfigurationManager cfg, bool synchronous) { _options = options; @@ -37,7 +36,7 @@ namespace Content.Server.Database // When inMemory we re-use the same connection, so we can't have any concurrency. var concurrency = inMemory ? 1 : cfg.GetCVar(CCVars.DatabaseSqliteConcurrency); - _prefsSemaphore = new SemaphoreSlim(concurrency, concurrency); + _prefsSemaphore = new ConcurrencySemaphore(concurrency, synchronous); if (cfg.GetCVar(CCVars.DatabaseSynchronous)) { @@ -564,5 +563,68 @@ namespace Content.Server.Database _db._prefsSemaphore.Release(); } } + + private sealed class ConcurrencySemaphore + { + private readonly bool _synchronous; + private readonly SemaphoreSlim _semaphore; + private Thread? _holdingThread; + + public ConcurrencySemaphore(int maxCount, bool synchronous) + { + if (synchronous && maxCount != 1) + throw new ArgumentException("If synchronous, max concurrency must be 1"); + + _synchronous = synchronous; + _semaphore = new SemaphoreSlim(maxCount, maxCount); + } + + public Task WaitAsync() + { + var task = _semaphore.WaitAsync(); + + if (_synchronous) + { + if (!task.IsCompleted) + { + if (Thread.CurrentThread == _holdingThread) + { + throw new InvalidOperationException( + "Multiple database requests from same thread on synchronous database!"); + } + + throw new InvalidOperationException( + $"Different threads trying to access the database at once! " + + $"Holding thread: {DiagThread(_holdingThread)}, " + + $"current thread: {DiagThread(Thread.CurrentThread)}"); + } + + _holdingThread = Thread.CurrentThread; + } + + return task; + } + + public void Release() + { + if (_synchronous) + { + if (Thread.CurrentThread != _holdingThread) + throw new InvalidOperationException("Released on different thread than took lock???"); + + _holdingThread = null; + } + + _semaphore.Release(); + } + + private static string DiagThread(Thread? thread) + { + if (thread != null) + return $"{thread.Name} ({thread.ManagedThreadId})"; + + return ""; + } + } } } diff --git a/Content.Server/GameTicking/GameTicker.Player.cs b/Content.Server/GameTicking/GameTicker.Player.cs index f18e00e4ab..894d133f4d 100644 --- a/Content.Server/GameTicking/GameTicker.Player.cs +++ b/Content.Server/GameTicking/GameTicker.Player.cs @@ -120,27 +120,9 @@ namespace Content.Server.GameTicking async void SpawnWaitDb() { - // Temporary debugging code to fix a random test failures - var initialStatus = _userDb.GetLoadTask(session).Status; - var prefsLoaded = _prefsManager.HavePreferencesLoaded(session); - DebugTools.Assert(session.Status == SessionStatus.InGame); - await _userDb.WaitLoadComplete(session); - try - { - SpawnPlayer(session, EntityUid.Invalid); - } - catch (Exception e) - { - Log.Error($"Caught exception while trying to spawn a player.\n" + - $"Initial DB task status: {initialStatus}\n" + - $"Prefs initially loaded: {prefsLoaded}\n" + - $"DB task status: {_userDb.GetLoadTask(session).Status}\n" + - $"Prefs loaded: {_prefsManager.HavePreferencesLoaded(session)}\n" + - $"Exception: \n{e}"); - throw; - } + SpawnPlayer(session, EntityUid.Invalid); } async void SpawnObserverWaitDb() diff --git a/Content.Tests/Server/Preferences/ServerDbSqliteTests.cs b/Content.Tests/Server/Preferences/ServerDbSqliteTests.cs index e18964817b..7c3282e201 100644 --- a/Content.Tests/Server/Preferences/ServerDbSqliteTests.cs +++ b/Content.Tests/Server/Preferences/ServerDbSqliteTests.cs @@ -75,7 +75,7 @@ namespace Content.Tests.Server.Preferences var conn = new SqliteConnection("Data Source=:memory:"); conn.Open(); builder.UseSqlite(conn); - return new ServerDbSqlite(() => builder.Options, true, IoCManager.Resolve()); + return new ServerDbSqlite(() => builder.Options, true, IoCManager.Resolve(), true); } [Test]