Kick on ban for entire server group (#28649)

* Start work on PostgresNotificationManager
Implement initial version of init and listening code

* Finish implementing PostgresNotificationManager
Implement ban insert trigger

* Implement ignoring notifications if the ban was from the same server

* Address reviews

* Fixes and refactorings

Fix typo in migration SQL

Pull new code in BanManager out into its own partial file.

Unify logic to kick somebody with that when a new ban is placed directly on the server.

New bans are now checked against all parameters (IP, HWID) instead of just user ID.

Extracted SQLite ban matching code into a new class so that it can mostly be re-used by the ban notification code. No copy-paste here.

Database notifications are now not implicitly sent to the main thread, this means basic checks will happen in the thread pool beforehand.

Bans without user ID are now sent to servers. Bans are rate limited to avoid undue work from mass ban imports, beyond the rate limit they are dropped.

Improved error handling and logging for the whole system.

Matching bans against connected players requires knowing their ban exemption flags. These are now cached when the player connects.

ServerBanDef now has exemption flags, again to allow matching full ban details for ban notifications.

Made database notifications a proper struct type to reduce copy pasting a tuple.

Remove copy pasted connection string building code by just... passing the string into the constructor.

Add lock around _notificationHandlers just in case.

Fixed postgres connection wait not being called in a loop and therefore spamming LISTEN commands for every received notification.

Added more error handling and logging to notification listener.

Removed some copy pasting from SQLite database layer too while I was at it because god forbid we expect anybody else to do all the work in this project.

Sorry Julian

---------

Co-authored-by: Pieter-Jan Briers <pieterjan.briers+git@gmail.com>
This commit is contained in:
Julian Giebel
2024-08-20 23:31:33 +02:00
committed by GitHub
parent 93497e484f
commit df95be1ce5
13 changed files with 2509 additions and 75 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace Content.Server.Database.Migrations.Postgres
{
/// <inheritdoc />
public partial class ban_notify_trigger : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
migrationBuilder.Sql("""
create or replace function send_server_ban_notification()
returns trigger as $$
declare
x_server_id integer;
begin
select round.server_id into x_server_id from round where round.round_id = NEW.round_id;
perform pg_notify('ban_notification', json_build_object('ban_id', NEW.server_ban_id, 'server_id', x_server_id)::text);
return NEW;
end;
$$ LANGUAGE plpgsql;
""");
migrationBuilder.Sql("""
create or replace trigger notify_on_server_ban_insert
after insert on server_ban
for each row
execute function send_server_ban_notification();
""");
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.Sql("""
drop trigger notify_on_server_ban_insert on server_ban;
drop function send_server_ban_notification;
""");
}
}
}

View File

@@ -705,6 +705,11 @@ namespace Content.Server.Database
/// Intended for use with residential IP ranges that are often used maliciously. /// Intended for use with residential IP ranges that are often used maliciously.
/// </remarks> /// </remarks>
BlacklistedRange = 1 << 2, BlacklistedRange = 1 << 2,
/// <summary>
/// Represents having all possible exemption flags.
/// </summary>
All = int.MaxValue,
// @formatter:on // @formatter:on
} }

View File

@@ -0,0 +1,123 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using Content.Server.Database;
namespace Content.Server.Administration.Managers;
public sealed partial class BanManager
{
// Responsible for ban notification handling.
// Ban notifications are sent through the database to notify the entire server group that a new ban has been added,
// so that people will get kicked if they are banned on a different server than the one that placed the ban.
//
// Ban notifications are currently sent by a trigger in the database, automatically.
/// <summary>
/// The notification channel used to broadcast information about new bans.
/// </summary>
public const string BanNotificationChannel = "ban_notification";
// Rate limit to avoid undue load from mass-ban imports.
// Only process 10 bans per 30 second interval.
//
// I had the idea of maybe binning this by postgres transaction ID,
// to avoid any possibility of dropping a normal ban by coincidence.
// Didn't bother implementing this though.
private static readonly TimeSpan BanNotificationRateLimitTime = TimeSpan.FromSeconds(30);
private const int BanNotificationRateLimitCount = 10;
private readonly object _banNotificationRateLimitStateLock = new();
private TimeSpan _banNotificationRateLimitStart;
private int _banNotificationRateLimitCount;
private void OnDatabaseNotification(DatabaseNotification notification)
{
if (notification.Channel != BanNotificationChannel)
return;
if (notification.Payload == null)
{
_sawmill.Error("Got ban notification with null payload!");
return;
}
BanNotificationData data;
try
{
data = JsonSerializer.Deserialize<BanNotificationData>(notification.Payload)
?? throw new JsonException("Content is null");
}
catch (JsonException e)
{
_sawmill.Error($"Got invalid JSON in ban notification: {e}");
return;
}
if (!CheckBanRateLimit())
{
_sawmill.Verbose("Not processing ban notification due to rate limit");
return;
}
_taskManager.RunOnMainThread(() => ProcessBanNotification(data));
}
private async void ProcessBanNotification(BanNotificationData data)
{
if ((await _entryManager.ServerEntity).Id == data.ServerId)
{
_sawmill.Verbose("Not processing ban notification: came from this server");
return;
}
_sawmill.Verbose($"Processing ban notification for ban {data.BanId}");
var ban = await _db.GetServerBanAsync(data.BanId);
if (ban == null)
{
_sawmill.Warning($"Ban in notification ({data.BanId}) didn't exist?");
return;
}
KickMatchingConnectedPlayers(ban, "ban notification");
}
private bool CheckBanRateLimit()
{
lock (_banNotificationRateLimitStateLock)
{
var now = _gameTiming.RealTime;
if (_banNotificationRateLimitStart + BanNotificationRateLimitTime < now)
{
// Rate limit period expired, restart it.
_banNotificationRateLimitCount = 1;
_banNotificationRateLimitStart = now;
return true;
}
_banNotificationRateLimitCount += 1;
return _banNotificationRateLimitCount <= BanNotificationRateLimitCount;
}
}
/// <summary>
/// Data sent along the notification channel for a single ban notification.
/// </summary>
private sealed class BanNotificationData
{
/// <summary>
/// The ID of the new ban object in the database to check.
/// </summary>
[JsonRequired, JsonPropertyName("ban_id")]
public int BanId { get; init; }
/// <summary>
/// The id of the server the ban was made on.
/// This is used to avoid double work checking the ban on the originating server.
/// </summary>
/// <remarks>
/// This is optional in case the ban was made outside a server (SS14.Admin)
/// </remarks>
[JsonPropertyName("server_id")]
public int? ServerId { get; init; }
}
}

View File

@@ -2,6 +2,7 @@ using System.Collections.Immutable;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Content.Server.Chat.Managers; using Content.Server.Chat.Managers;
using Content.Server.Database; using Content.Server.Database;
@@ -12,16 +13,18 @@ using Content.Shared.Players;
using Content.Shared.Players.PlayTimeTracking; using Content.Shared.Players.PlayTimeTracking;
using Content.Shared.Roles; using Content.Shared.Roles;
using Robust.Server.Player; using Robust.Server.Player;
using Robust.Shared.Asynchronous;
using Robust.Shared.Configuration; using Robust.Shared.Configuration;
using Robust.Shared.Enums; using Robust.Shared.Enums;
using Robust.Shared.Network; using Robust.Shared.Network;
using Robust.Shared.Player; using Robust.Shared.Player;
using Robust.Shared.Prototypes; using Robust.Shared.Prototypes;
using Robust.Shared.Timing;
using Robust.Shared.Utility; using Robust.Shared.Utility;
namespace Content.Server.Administration.Managers; namespace Content.Server.Administration.Managers;
public sealed class BanManager : IBanManager, IPostInjectInit public sealed partial class BanManager : IBanManager, IPostInjectInit
{ {
[Dependency] private readonly IServerDbManager _db = default!; [Dependency] private readonly IServerDbManager _db = default!;
[Dependency] private readonly IPlayerManager _playerManager = default!; [Dependency] private readonly IPlayerManager _playerManager = default!;
@@ -29,9 +32,13 @@ public sealed class BanManager : IBanManager, IPostInjectInit
[Dependency] private readonly IEntitySystemManager _systems = default!; [Dependency] private readonly IEntitySystemManager _systems = default!;
[Dependency] private readonly IConfigurationManager _cfg = default!; [Dependency] private readonly IConfigurationManager _cfg = default!;
[Dependency] private readonly ILocalizationManager _localizationManager = default!; [Dependency] private readonly ILocalizationManager _localizationManager = default!;
[Dependency] private readonly ServerDbEntryManager _entryManager = default!;
[Dependency] private readonly IChatManager _chat = default!; [Dependency] private readonly IChatManager _chat = default!;
[Dependency] private readonly INetManager _netManager = default!; [Dependency] private readonly INetManager _netManager = default!;
[Dependency] private readonly ILogManager _logManager = default!; [Dependency] private readonly ILogManager _logManager = default!;
[Dependency] private readonly IGameTiming _gameTiming = default!;
[Dependency] private readonly ITaskManager _taskManager = default!;
[Dependency] private readonly UserDbDataManager _userDbData = default!;
private ISawmill _sawmill = default!; private ISawmill _sawmill = default!;
@@ -39,12 +46,34 @@ public sealed class BanManager : IBanManager, IPostInjectInit
public const string JobPrefix = "Job:"; public const string JobPrefix = "Job:";
private readonly Dictionary<NetUserId, HashSet<ServerRoleBanDef>> _cachedRoleBans = new(); private readonly Dictionary<NetUserId, HashSet<ServerRoleBanDef>> _cachedRoleBans = new();
// Cached ban exemption flags are used to handle
private readonly Dictionary<ICommonSession, ServerBanExemptFlags> _cachedBanExemptions = new();
public void Initialize() public void Initialize()
{ {
_playerManager.PlayerStatusChanged += OnPlayerStatusChanged; _playerManager.PlayerStatusChanged += OnPlayerStatusChanged;
_netManager.RegisterNetMessage<MsgRoleBans>(); _netManager.RegisterNetMessage<MsgRoleBans>();
_db.SubscribeToNotifications(OnDatabaseNotification);
_userDbData.AddOnLoadPlayer(CachePlayerData);
_userDbData.AddOnPlayerDisconnect(ClearPlayerData);
}
private async Task CachePlayerData(ICommonSession player, CancellationToken cancel)
{
// Yeah so role ban loading code isn't integrated with exempt flag loading code.
// Have you seen how garbage role ban code code is? I don't feel like refactoring it right now.
var flags = await _db.GetBanExemption(player.UserId, cancel);
cancel.ThrowIfCancellationRequested();
_cachedBanExemptions[player] = flags;
}
private void ClearPlayerData(ICommonSession player)
{
_cachedBanExemptions.Remove(player);
} }
private async void OnPlayerStatusChanged(object? sender, SessionStatusEventArgs e) private async void OnPlayerStatusChanged(object? sender, SessionStatusEventArgs e)
@@ -168,17 +197,43 @@ public sealed class BanManager : IBanManager, IPostInjectInit
_sawmill.Info(logMessage); _sawmill.Info(logMessage);
_chat.SendAdminAlert(logMessage); _chat.SendAdminAlert(logMessage);
// If we're not banning a player we don't care about disconnecting people KickMatchingConnectedPlayers(banDef, "newly placed ban");
if (target == null)
return;
// Is the player connected?
if (!_playerManager.TryGetSessionById(target.Value, out var targetPlayer))
return;
// If they are, kick them
var message = banDef.FormatBanMessage(_cfg, _localizationManager);
targetPlayer.Channel.Disconnect(message);
} }
private void KickMatchingConnectedPlayers(ServerBanDef def, string source)
{
foreach (var player in _playerManager.Sessions)
{
if (BanMatchesPlayer(player, def))
{
KickForBanDef(player, def);
_sawmill.Info($"Kicked player {player.Name} ({player.UserId}) through {source}");
}
}
}
private bool BanMatchesPlayer(ICommonSession player, ServerBanDef ban)
{
var playerInfo = new BanMatcher.PlayerInfo
{
UserId = player.UserId,
Address = player.Channel.RemoteEndPoint.Address,
HWId = player.Channel.UserData.HWId,
// It's possible for the player to not have cached data loading yet due to coincidental timing.
// If this is the case, we assume they have all flags to avoid false-positives.
ExemptFlags = _cachedBanExemptions.GetValueOrDefault(player, ServerBanExemptFlags.All),
IsNewPlayer = false,
};
return BanMatcher.BanMatches(ban, playerInfo);
}
private void KickForBanDef(ICommonSession player, ServerBanDef def)
{
var message = def.FormatBanMessage(_cfg, _localizationManager);
player.Channel.Disconnect(message);
}
#endregion #endregion
#region Job Bans #region Job Bans

View File

@@ -0,0 +1,90 @@
using System.Collections.Immutable;
using System.Net;
using Content.Server.IP;
using Robust.Shared.Network;
namespace Content.Server.Database;
/// <summary>
/// Implements logic to match a <see cref="ServerBanDef"/> against a player query.
/// </summary>
/// <remarks>
/// <para>
/// This implementation is used by in-game ban matching code, and partially by the SQLite database layer.
/// Some logic is duplicated into both the SQLite and PostgreSQL database layers to provide more optimal SQL queries.
/// Both should be kept in sync, please!
/// </para>
/// </remarks>
public static class BanMatcher
{
/// <summary>
/// Check whether a ban matches the specified player info.
/// </summary>
/// <remarks>
/// <para>
/// This function does not check whether the ban itself is expired or manually unbanned.
/// </para>
/// </remarks>
/// <param name="ban">The ban information.</param>
/// <param name="player">Information about the player to match against.</param>
/// <returns>True if the ban matches the provided player info.</returns>
public static bool BanMatches(ServerBanDef ban, in PlayerInfo player)
{
var exemptFlags = player.ExemptFlags;
// Any flag to bypass BlacklistedRange bans.
if (exemptFlags != ServerBanExemptFlags.None)
exemptFlags |= ServerBanExemptFlags.BlacklistedRange;
if ((ban.ExemptFlags & exemptFlags) != 0)
return false;
if (!player.ExemptFlags.HasFlag(ServerBanExemptFlags.IP)
&& player.Address != null
&& ban.Address is not null
&& player.Address.IsInSubnet(ban.Address.Value)
&& (!ban.ExemptFlags.HasFlag(ServerBanExemptFlags.BlacklistedRange) || player.IsNewPlayer))
{
return true;
}
if (player.UserId is { } id && ban.UserId == id.UserId)
{
return true;
}
return player.HWId is { Length: > 0 } hwIdVar
&& ban.HWId != null
&& hwIdVar.AsSpan().SequenceEqual(ban.HWId.Value.AsSpan());
}
/// <summary>
/// A simple struct containing player info used to match bans against.
/// </summary>
public struct PlayerInfo
{
/// <summary>
/// The user ID of the player.
/// </summary>
public NetUserId? UserId;
/// <summary>
/// The IP address of the player.
/// </summary>
public IPAddress? Address;
/// <summary>
/// The hardware ID of the player.
/// </summary>
public ImmutableArray<byte>? HWId;
/// <summary>
/// Exemption flags the player has been granted.
/// </summary>
public ServerBanExemptFlags ExemptFlags;
/// <summary>
/// True if this player is new and is thus eligible for more bans.
/// </summary>
public bool IsNewPlayer;
}
}

View File

@@ -23,9 +23,9 @@ namespace Content.Server.Database
public NoteSeverity Severity { get; set; } public NoteSeverity Severity { get; set; }
public NetUserId? BanningAdmin { get; } public NetUserId? BanningAdmin { get; }
public ServerUnbanDef? Unban { get; } public ServerUnbanDef? Unban { get; }
public ServerBanExemptFlags ExemptFlags { get; }
public ServerBanDef( public ServerBanDef(int? id,
int? id,
NetUserId? userId, NetUserId? userId,
(IPAddress, int)? address, (IPAddress, int)? address,
ImmutableArray<byte>? hwId, ImmutableArray<byte>? hwId,
@@ -36,7 +36,8 @@ namespace Content.Server.Database
string reason, string reason,
NoteSeverity severity, NoteSeverity severity,
NetUserId? banningAdmin, NetUserId? banningAdmin,
ServerUnbanDef? unban) ServerUnbanDef? unban,
ServerBanExemptFlags exemptFlags = default)
{ {
if (userId == null && address == null && hwId == null) if (userId == null && address == null && hwId == null)
{ {
@@ -62,6 +63,7 @@ namespace Content.Server.Database
Severity = severity; Severity = severity;
BanningAdmin = banningAdmin; BanningAdmin = banningAdmin;
Unban = unban; Unban = unban;
ExemptFlags = exemptFlags;
} }
public string FormatBanMessage(IConfigurationManager cfg, ILocalizationManager loc) public string FormatBanMessage(IConfigurationManager cfg, ILocalizationManager loc)

View File

@@ -28,6 +28,8 @@ namespace Content.Server.Database
{ {
private readonly ISawmill _opsLog; private readonly ISawmill _opsLog;
public event Action<DatabaseNotification>? OnNotificationReceived;
/// <param name="opsLog">Sawmill to trace log database operations to.</param> /// <param name="opsLog">Sawmill to trace log database operations to.</param>
public ServerDbBase(ISawmill opsLog) public ServerDbBase(ISawmill opsLog)
{ {
@@ -425,13 +427,16 @@ namespace Content.Server.Database
await db.DbContext.SaveChangesAsync(); await db.DbContext.SaveChangesAsync();
} }
protected static async Task<ServerBanExemptFlags?> GetBanExemptionCore(DbGuard db, NetUserId? userId) protected static async Task<ServerBanExemptFlags?> GetBanExemptionCore(
DbGuard db,
NetUserId? userId,
CancellationToken cancel = default)
{ {
if (userId == null) if (userId == null)
return null; return null;
var exemption = await db.DbContext.BanExemption var exemption = await db.DbContext.BanExemption
.SingleOrDefaultAsync(e => e.UserId == userId.Value.UserId); .SingleOrDefaultAsync(e => e.UserId == userId.Value.UserId, cancellationToken: cancel);
return exemption?.Flags; return exemption?.Flags;
} }
@@ -462,11 +467,11 @@ namespace Content.Server.Database
await db.DbContext.SaveChangesAsync(); await db.DbContext.SaveChangesAsync();
} }
public async Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId) public async Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId, CancellationToken cancel)
{ {
await using var db = await GetDb(); await using var db = await GetDb(cancel);
var flags = await GetBanExemptionCore(db, userId); var flags = await GetBanExemptionCore(db, userId, cancel);
return flags ?? ServerBanExemptFlags.None; return flags ?? ServerBanExemptFlags.None;
} }
@@ -1677,5 +1682,15 @@ INSERT INTO player_round (players_id, rounds_id) VALUES ({players[player]}, {id}
public abstract ValueTask DisposeAsync(); public abstract ValueTask DisposeAsync();
} }
protected void NotificationReceived(DatabaseNotification notification)
{
OnNotificationReceived?.Invoke(notification);
}
public virtual void Shutdown()
{
}
} }
} }

View File

@@ -116,7 +116,7 @@ namespace Content.Server.Database
/// Get current ban exemption flags for a user /// Get current ban exemption flags for a user
/// </summary> /// </summary>
/// <returns><see cref="ServerBanExemptFlags.None"/> if the user is not exempt from any bans.</returns> /// <returns><see cref="ServerBanExemptFlags.None"/> if the user is not exempt from any bans.</returns>
Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId); Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId, CancellationToken cancel = default);
#endregion #endregion
@@ -304,6 +304,43 @@ namespace Content.Server.Database
Task<bool> RemoveJobWhitelist(Guid player, ProtoId<JobPrototype> job); Task<bool> RemoveJobWhitelist(Guid player, ProtoId<JobPrototype> job);
#endregion #endregion
#region DB Notifications
void SubscribeToNotifications(Action<DatabaseNotification> handler);
/// <summary>
/// Inject a notification as if it was created by the database. This is intended for testing.
/// </summary>
/// <param name="notification">The notification to trigger</param>
void InjectTestNotification(DatabaseNotification notification);
#endregion
}
/// <summary>
/// Represents a notification sent between servers via the database layer.
/// </summary>
/// <remarks>
/// <para>
/// Database notifications are a simple system to broadcast messages to an entire server group
/// backed by the same database. For example, this is used to notify all servers of new ban records.
/// </para>
/// <para>
/// They are currently implemented by the PostgreSQL <c>NOTIFY</c> and <c>LISTEN</c> commands.
/// </para>
/// </remarks>
public struct DatabaseNotification
{
/// <summary>
/// The channel for the notification. This can be used to differentiate notifications for different purposes.
/// </summary>
public required string Channel { get; set; }
/// <summary>
/// The actual contents of the notification. Optional.
/// </summary>
public string? Payload { get; set; }
} }
public sealed class ServerDbManager : IServerDbManager public sealed class ServerDbManager : IServerDbManager
@@ -333,6 +370,8 @@ namespace Content.Server.Database
// This is that connection, close it when we shut down. // This is that connection, close it when we shut down.
private SqliteConnection? _sqliteInMemoryConnection; private SqliteConnection? _sqliteInMemoryConnection;
private readonly List<Action<DatabaseNotification>> _notificationHandlers = [];
public void Init() public void Init()
{ {
_msLogProvider = new LoggingProvider(_logMgr); _msLogProvider = new LoggingProvider(_logMgr);
@@ -345,6 +384,7 @@ namespace Content.Server.Database
var engine = _cfg.GetCVar(CCVars.DatabaseEngine).ToLower(); var engine = _cfg.GetCVar(CCVars.DatabaseEngine).ToLower();
var opsLog = _logMgr.GetSawmill("db.op"); var opsLog = _logMgr.GetSawmill("db.op");
var notifyLog = _logMgr.GetSawmill("db.notify");
switch (engine) switch (engine)
{ {
case "sqlite": case "sqlite":
@@ -352,17 +392,22 @@ namespace Content.Server.Database
_db = new ServerDbSqlite(contextFunc, inMemory, _cfg, _synchronous, opsLog); _db = new ServerDbSqlite(contextFunc, inMemory, _cfg, _synchronous, opsLog);
break; break;
case "postgres": case "postgres":
var pgOptions = CreatePostgresOptions(); var (pgOptions, conString) = CreatePostgresOptions();
_db = new ServerDbPostgres(pgOptions, _cfg, opsLog); _db = new ServerDbPostgres(pgOptions, conString, _cfg, opsLog, notifyLog);
break; break;
default: default:
throw new InvalidDataException($"Unknown database engine {engine}."); throw new InvalidDataException($"Unknown database engine {engine}.");
} }
_db.OnNotificationReceived += HandleDatabaseNotification;
} }
public void Shutdown() public void Shutdown()
{ {
_db.OnNotificationReceived -= HandleDatabaseNotification;
_sqliteInMemoryConnection?.Dispose(); _sqliteInMemoryConnection?.Dispose();
_db.Shutdown();
} }
public Task<PlayerPreferences> InitPrefsAsync( public Task<PlayerPreferences> InitPrefsAsync(
@@ -465,10 +510,10 @@ namespace Content.Server.Database
return RunDbCommand(() => _db.UpdateBanExemption(userId, flags)); return RunDbCommand(() => _db.UpdateBanExemption(userId, flags));
} }
public Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId) public Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId, CancellationToken cancel = default)
{ {
DbReadOpsMetric.Inc(); DbReadOpsMetric.Inc();
return RunDbCommand(() => _db.GetBanExemption(userId)); return RunDbCommand(() => _db.GetBanExemption(userId, cancel));
} }
#region Role Ban #region Role Ban
@@ -907,6 +952,30 @@ namespace Content.Server.Database
return RunDbCommand(() => _db.RemoveJobWhitelist(player, job)); return RunDbCommand(() => _db.RemoveJobWhitelist(player, job));
} }
public void SubscribeToNotifications(Action<DatabaseNotification> handler)
{
lock (_notificationHandlers)
{
_notificationHandlers.Add(handler);
}
}
public void InjectTestNotification(DatabaseNotification notification)
{
HandleDatabaseNotification(notification);
}
private async void HandleDatabaseNotification(DatabaseNotification notification)
{
lock (_notificationHandlers)
{
foreach (var handler in _notificationHandlers)
{
handler(notification);
}
}
}
// Wrapper functions to run DB commands from the thread pool. // Wrapper functions to run DB commands from the thread pool.
// This will avoid SynchronizationContext capturing and avoid running CPU work on the main thread. // This will avoid SynchronizationContext capturing and avoid running CPU work on the main thread.
// For SQLite, this will also enable read parallelization (within limits). // For SQLite, this will also enable read parallelization (within limits).
@@ -962,7 +1031,7 @@ namespace Content.Server.Database
return enumerable; return enumerable;
} }
private DbContextOptions<PostgresServerDbContext> CreatePostgresOptions() private (DbContextOptions<PostgresServerDbContext> options, string connectionString) CreatePostgresOptions()
{ {
var host = _cfg.GetCVar(CCVars.DatabasePgHost); var host = _cfg.GetCVar(CCVars.DatabasePgHost);
var port = _cfg.GetCVar(CCVars.DatabasePgPort); var port = _cfg.GetCVar(CCVars.DatabasePgPort);
@@ -984,7 +1053,7 @@ namespace Content.Server.Database
builder.UseNpgsql(connectionString); builder.UseNpgsql(connectionString);
SetupLogging(builder); SetupLogging(builder);
return builder.Options; return (builder.Options, connectionString);
} }
private void SetupSqlite(out Func<DbContextOptions<SqliteServerDbContext>> contextFunc, out bool inMemory) private void SetupSqlite(out Func<DbContextOptions<SqliteServerDbContext>> contextFunc, out bool inMemory)

View File

@@ -0,0 +1,121 @@
using System.Data;
using System.Threading;
using System.Threading.Tasks;
using Content.Server.Administration.Managers;
using Npgsql;
namespace Content.Server.Database;
/// Listens for ban_notification containing the player id and the banning server id using postgres listen/notify.
/// Players a ban_notification got received for get banned, except when the current server id and the one in the notification payload match.
public sealed partial class ServerDbPostgres
{
/// <summary>
/// The list of notify channels to subscribe to.
/// </summary>
private static readonly string[] NotificationChannels =
[
BanManager.BanNotificationChannel,
];
private static readonly TimeSpan ReconnectWaitIncrease = TimeSpan.FromSeconds(10);
private readonly CancellationTokenSource _notificationTokenSource = new();
private NpgsqlConnection? _notificationConnection;
private TimeSpan _reconnectWaitTime = TimeSpan.Zero;
/// <summary>
/// Sets up the database connection and the notification handler
/// </summary>
private void InitNotificationListener(string connectionString)
{
_notificationConnection = new NpgsqlConnection(connectionString);
_notificationConnection.Notification += OnNotification;
var cancellationToken = _notificationTokenSource.Token;
Task.Run(() => NotificationListener(cancellationToken), cancellationToken);
}
/// <summary>
/// Listens to the notification channel with basic error handling and reopens the connection if it got closed
/// </summary>
private async Task NotificationListener(CancellationToken cancellationToken)
{
if (_notificationConnection == null)
return;
_notifyLog.Verbose("Starting notification listener");
while (!cancellationToken.IsCancellationRequested)
{
try
{
if (_notificationConnection.State == ConnectionState.Broken)
{
_notifyLog.Debug("Notification listener entered broken state, closing...");
await _notificationConnection.CloseAsync();
}
if (_notificationConnection.State == ConnectionState.Closed)
{
_notifyLog.Debug("Opening notification listener connection...");
if (_reconnectWaitTime != TimeSpan.Zero)
{
_notifyLog.Verbose($"_reconnectWaitTime is {_reconnectWaitTime}");
await Task.Delay(_reconnectWaitTime, cancellationToken);
}
await _notificationConnection.OpenAsync(cancellationToken);
_reconnectWaitTime = TimeSpan.Zero;
_notifyLog.Verbose($"Notification connection opened...");
}
foreach (var channel in NotificationChannels)
{
_notifyLog.Verbose($"Listening on channel {channel}");
await using var cmd = new NpgsqlCommand($"LISTEN {channel}", _notificationConnection);
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
while (!cancellationToken.IsCancellationRequested)
{
_notifyLog.Verbose("Waiting on notifications...");
await _notificationConnection.WaitAsync(cancellationToken);
}
}
catch (OperationCanceledException)
{
// Abort loop on cancel.
_notifyLog.Verbose($"Shutting down notification listener due to cancellation");
return;
}
catch (Exception e)
{
_reconnectWaitTime += ReconnectWaitIncrease;
_notifyLog.Error($"Error in notification listener: {e}");
}
}
}
private void OnNotification(object _, NpgsqlNotificationEventArgs notification)
{
_notifyLog.Verbose($"Received notification on channel {notification.Channel}");
NotificationReceived(new DatabaseNotification
{
Channel = notification.Channel,
Payload = notification.Payload,
});
}
public override void Shutdown()
{
_notificationTokenSource.Cancel();
if (_notificationConnection == null)
return;
_notificationConnection.Notification -= OnNotification;
_notificationConnection.Dispose();
}
}

View File

@@ -16,23 +16,26 @@ using Robust.Shared.Utility;
namespace Content.Server.Database namespace Content.Server.Database
{ {
public sealed class ServerDbPostgres : ServerDbBase public sealed partial class ServerDbPostgres : ServerDbBase
{ {
private readonly DbContextOptions<PostgresServerDbContext> _options; private readonly DbContextOptions<PostgresServerDbContext> _options;
private readonly ISawmill _notifyLog;
private readonly SemaphoreSlim _prefsSemaphore; private readonly SemaphoreSlim _prefsSemaphore;
private readonly Task _dbReadyTask; private readonly Task _dbReadyTask;
private int _msLag; private int _msLag;
public ServerDbPostgres( public ServerDbPostgres(DbContextOptions<PostgresServerDbContext> options,
DbContextOptions<PostgresServerDbContext> options, string connectionString,
IConfigurationManager cfg, IConfigurationManager cfg,
ISawmill opsLog) ISawmill opsLog,
ISawmill notifyLog)
: base(opsLog) : base(opsLog)
{ {
var concurrency = cfg.GetCVar(CCVars.DatabasePgConcurrency); var concurrency = cfg.GetCVar(CCVars.DatabasePgConcurrency);
_options = options; _options = options;
_notifyLog = notifyLog;
_prefsSemaphore = new SemaphoreSlim(concurrency, concurrency); _prefsSemaphore = new SemaphoreSlim(concurrency, concurrency);
_dbReadyTask = Task.Run(async () => _dbReadyTask = Task.Run(async () =>
@@ -49,6 +52,8 @@ namespace Content.Server.Database
}); });
cfg.OnValueChanged(CCVars.DatabasePgFakeLag, v => _msLag = v, true); cfg.OnValueChanged(CCVars.DatabasePgFakeLag, v => _msLag = v, true);
InitNotificationListener(connectionString);
} }
#region Ban #region Ban
@@ -214,7 +219,8 @@ namespace Content.Server.Database
ban.Reason, ban.Reason,
ban.Severity, ban.Severity,
aUid, aUid,
unbanDef); unbanDef,
ban.ExemptFlags);
} }
private static ServerUnbanDef? ConvertUnban(ServerUnban? unban) private static ServerUnbanDef? ConvertUnban(ServerUnban? unban)
@@ -251,7 +257,8 @@ namespace Content.Server.Database
ExpirationTime = serverBan.ExpirationTime?.UtcDateTime, ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
RoundId = serverBan.RoundId, RoundId = serverBan.RoundId,
PlaytimeAtNote = serverBan.PlaytimeAtNote, PlaytimeAtNote = serverBan.PlaytimeAtNote,
PlayerUserId = serverBan.UserId?.UserId PlayerUserId = serverBan.UserId?.UserId,
ExemptFlags = serverBan.ExemptFlags
}); });
await db.PgDbContext.SaveChangesAsync(); await db.PgDbContext.SaveChangesAsync();

View File

@@ -84,25 +84,27 @@ namespace Content.Server.Database
{ {
await using var db = await GetDbImpl(); await using var db = await GetDbImpl();
var exempt = await GetBanExemptionCore(db, userId); return (await GetServerBanQueryAsync(db, address, userId, hwId, includeUnbanned: false)).FirstOrDefault();
var newPlayer = userId == null || !await PlayerRecordExists(db, userId.Value);
// SQLite can't do the net masking stuff we need to match IP address ranges.
// So just pull down the whole list into memory.
var bans = await GetAllBans(db.SqliteDbContext, includeUnbanned: false, exempt);
return bans.FirstOrDefault(b => BanMatches(b, address, userId, hwId, exempt, newPlayer)) is { } foundBan
? ConvertBan(foundBan)
: null;
} }
public override async Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, public override async Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId, NetUserId? userId,
ImmutableArray<byte>? hwId, bool includeUnbanned) ImmutableArray<byte>? hwId,
bool includeUnbanned)
{ {
await using var db = await GetDbImpl(); await using var db = await GetDbImpl();
return (await GetServerBanQueryAsync(db, address, userId, hwId, includeUnbanned)).ToList();
}
private async Task<IEnumerable<ServerBanDef>> GetServerBanQueryAsync(
DbGuardImpl db,
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId,
bool includeUnbanned)
{
var exempt = await GetBanExemptionCore(db, userId); var exempt = await GetBanExemptionCore(db, userId);
var newPlayer = !await db.SqliteDbContext.Player.AnyAsync(p => p.UserId == userId); var newPlayer = !await db.SqliteDbContext.Player.AnyAsync(p => p.UserId == userId);
@@ -111,10 +113,18 @@ namespace Content.Server.Database
// So just pull down the whole list into memory. // So just pull down the whole list into memory.
var queryBans = await GetAllBans(db.SqliteDbContext, includeUnbanned, exempt); var queryBans = await GetAllBans(db.SqliteDbContext, includeUnbanned, exempt);
var playerInfo = new BanMatcher.PlayerInfo
{
Address = address,
UserId = userId,
ExemptFlags = exempt ?? default,
HWId = hwId,
IsNewPlayer = newPlayer,
};
return queryBans return queryBans
.Where(b => BanMatches(b, address, userId, hwId, exempt, newPlayer))
.Select(ConvertBan) .Select(ConvertBan)
.ToList()!; .Where(b => BanMatcher.BanMatches(b!, playerInfo))!;
} }
private static async Task<List<ServerBan>> GetAllBans( private static async Task<List<ServerBan>> GetAllBans(
@@ -141,31 +151,6 @@ namespace Content.Server.Database
return await query.ToListAsync(); return await query.ToListAsync();
} }
private static bool BanMatches(ServerBan ban,
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId,
ServerBanExemptFlags? exemptFlags,
bool newPlayer)
{
if (!exemptFlags.GetValueOrDefault(ServerBanExemptFlags.None).HasFlag(ServerBanExemptFlags.IP)
&& address != null
&& ban.Address is not null
&& address.IsInSubnet(ban.Address.ToTuple().Value)
&& (!ban.ExemptFlags.HasFlag(ServerBanExemptFlags.BlacklistedRange) ||
newPlayer))
{
return true;
}
if (userId is { } id && ban.PlayerUserId == id.UserId)
{
return true;
}
return hwId is { Length: > 0 } hwIdVar && hwIdVar.AsSpan().SequenceEqual(ban.HWId);
}
public override async Task AddServerBanAsync(ServerBanDef serverBan) public override async Task AddServerBanAsync(ServerBanDef serverBan)
{ {
await using var db = await GetDbImpl(); await using var db = await GetDbImpl();
@@ -181,7 +166,8 @@ namespace Content.Server.Database
ExpirationTime = serverBan.ExpirationTime?.UtcDateTime, ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
RoundId = serverBan.RoundId, RoundId = serverBan.RoundId,
PlaytimeAtNote = serverBan.PlaytimeAtNote, PlaytimeAtNote = serverBan.PlaytimeAtNote,
PlayerUserId = serverBan.UserId?.UserId PlayerUserId = serverBan.UserId?.UserId,
ExemptFlags = serverBan.ExemptFlags
}); });
await db.SqliteDbContext.SaveChangesAsync(); await db.SqliteDbContext.SaveChangesAsync();
@@ -364,6 +350,7 @@ namespace Content.Server.Database
} }
#endregion #endregion
[return: NotNullIfNotNull(nameof(ban))]
private static ServerBanDef? ConvertBan(ServerBan? ban) private static ServerBanDef? ConvertBan(ServerBan? ban)
{ {
if (ban == null) if (ban == null)

View File

@@ -82,3 +82,6 @@ ban-panel-erase = Erase chat messages and player from round
server-ban-string = {$admin} created a {$severity} severity server ban that expires {$expires} for [{$name}, {$ip}, {$hwid}], with reason: {$reason} server-ban-string = {$admin} created a {$severity} severity server ban that expires {$expires} for [{$name}, {$ip}, {$hwid}], with reason: {$reason}
server-ban-string-no-pii = {$admin} created a {$severity} severity server ban that expires {$expires} for {$name} with reason: {$reason} server-ban-string-no-pii = {$admin} created a {$severity} severity server ban that expires {$expires} for {$name} with reason: {$reason}
server-ban-string-never = never server-ban-string-never = never
# Kick on ban
ban-kick-reason = You have been banned