Add EntityWhitelistSystem (#27632)

* Add EntityWhitelistSystem

* Sandbox fix

* update test
This commit is contained in:
Leon Friedrich
2024-05-03 12:10:15 +12:00
committed by GitHub
parent 291ecf9643
commit f348e6aa30
4 changed files with 117 additions and 89 deletions

View File

@@ -64,7 +64,8 @@ namespace Content.IntegrationTests.Tests.Utility
var testMap = await pair.CreateTestMap(); var testMap = await pair.CreateTestMap();
var mapCoordinates = testMap.MapCoords; var mapCoordinates = testMap.MapCoords;
var sEntities = server.ResolveDependency<IEntityManager>(); var sEntities = server.EntMan;
var sys = server.System<EntityWhitelistSystem>();
await server.WaitAssertion(() => await server.WaitAssertion(() =>
{ {
@@ -80,22 +81,14 @@ namespace Content.IntegrationTests.Tests.Utility
Components = new[] { $"{ValidComponent}" }, Components = new[] { $"{ValidComponent}" },
Tags = new() { "WhitelistTestValidTag" } Tags = new() { "WhitelistTestValidTag" }
}; };
whitelistInst.UpdateRegistrations();
Assert.That(whitelistInst, Is.Not.Null);
Assert.Multiple(() => Assert.Multiple(() =>
{ {
Assert.That(whitelistInst.Components, Is.Not.Null); Assert.That(sys.IsValid(whitelistInst, validComponent), Is.True);
Assert.That(whitelistInst.Tags, Is.Not.Null); Assert.That(sys.IsValid(whitelistInst, WhitelistTestValidTag), Is.True);
});
Assert.Multiple(() => Assert.That(sys.IsValid(whitelistInst, invalidComponent), Is.False);
{ Assert.That(sys.IsValid(whitelistInst, WhitelistTestInvalidTag), Is.False);
Assert.That(whitelistInst.IsValid(validComponent), Is.True);
Assert.That(whitelistInst.IsValid(WhitelistTestValidTag), Is.True);
Assert.That(whitelistInst.IsValid(invalidComponent), Is.False);
Assert.That(whitelistInst.IsValid(WhitelistTestInvalidTag), Is.False);
}); });
// Test from serialized // Test from serialized
@@ -111,11 +104,11 @@ namespace Content.IntegrationTests.Tests.Utility
Assert.Multiple(() => Assert.Multiple(() =>
{ {
Assert.That(whitelistSer.IsValid(validComponent), Is.True); Assert.That(sys.IsValid(whitelistSer, validComponent), Is.True);
Assert.That(whitelistSer.IsValid(WhitelistTestValidTag), Is.True); Assert.That(sys.IsValid(whitelistSer, WhitelistTestValidTag), Is.True);
Assert.That(whitelistSer.IsValid(invalidComponent), Is.False); Assert.That(sys.IsValid(whitelistSer, invalidComponent), Is.False);
Assert.That(whitelistSer.IsValid(WhitelistTestInvalidTag), Is.False); Assert.That(sys.IsValid(whitelistSer, WhitelistTestInvalidTag), Is.False);
}); });
}); });
await pair.CleanReturnAsync(); await pair.CleanReturnAsync();

View File

@@ -1,3 +1,4 @@
using System.Diagnostics;
using System.Linq; using System.Linq;
using Robust.Shared.GameStates; using Robust.Shared.GameStates;
using Robust.Shared.Prototypes; using Robust.Shared.Prototypes;
@@ -185,6 +186,7 @@ public sealed class TagSystem : EntitySystem
/// <summary> /// <summary>
/// Checks if a tag has been added to an entity. /// Checks if a tag has been added to an entity.
/// </summary> /// </summary>
[Obsolete]
public bool HasTag(EntityUid entity, string id, EntityQuery<TagComponent> tagQuery) public bool HasTag(EntityUid entity, string id, EntityQuery<TagComponent> tagQuery)
{ {
return tagQuery.TryGetComponent(entity, out var component) && return tagQuery.TryGetComponent(entity, out var component) &&
@@ -243,7 +245,7 @@ public sealed class TagSystem : EntitySystem
/// </exception> /// </exception>
public bool HasAllTags(EntityUid entity, List<ProtoId<TagPrototype>> ids) public bool HasAllTags(EntityUid entity, List<ProtoId<TagPrototype>> ids)
{ {
return TryComp<TagComponent>(entity, out var component) && return _tagQuery.TryComp(entity, out var component) &&
HasAllTags(component, ids); HasAllTags(component, ids);
} }
@@ -521,16 +523,18 @@ public sealed class TagSystem : EntitySystem
/// </exception> /// </exception>
public bool HasAllTags(TagComponent component, List<ProtoId<TagPrototype>> ids) public bool HasAllTags(TagComponent component, List<ProtoId<TagPrototype>> ids)
{ {
var stringIds = new List<string>(); foreach (var id in ids)
foreach (var tag in ids)
{ {
stringIds.Add(tag.Id); AssertValidTag(id);
if (!component.Tags.Contains(id))
return false;
} }
return HasAllTags(component, stringIds); return true;
} }
/// <summary> /// <summary>
/// Checks if any of the given tags have been added. /// Checks if any of the given tags have been added.
/// </summary> /// </summary>
@@ -552,7 +556,6 @@ public sealed class TagSystem : EntitySystem
return false; return false;
} }
/// <summary> /// <summary>
/// Checks if any of the given tags have been added. /// Checks if any of the given tags have been added.
/// </summary> /// </summary>
@@ -619,13 +622,15 @@ public sealed class TagSystem : EntitySystem
/// </exception> /// </exception>
public bool HasAnyTag(TagComponent comp, List<ProtoId<TagPrototype>> ids) public bool HasAnyTag(TagComponent comp, List<ProtoId<TagPrototype>> ids)
{ {
var stringIds = new List<string>(); foreach (var id in ids)
foreach (var tag in ids)
{ {
stringIds.Add(tag.Id); AssertValidTag(id);
if (comp.Tags.Contains(id))
return true;
} }
return HasAnyTag(comp, stringIds); return false;
} }
/// <summary> /// <summary>

View File

@@ -38,8 +38,8 @@ public sealed partial class EntityWhitelist
[DataField] [DataField]
public List<ProtoId<ItemSizePrototype>>? Sizes; public List<ProtoId<ItemSizePrototype>>? Sizes;
[NonSerialized] [NonSerialized, Access(typeof(EntityWhitelistSystem))]
private List<ComponentRegistration>? _registrations; public List<ComponentRegistration>? Registrations;
/// <summary> /// <summary>
/// Tags that are allowed in the whitelist. /// Tags that are allowed in the whitelist.
@@ -55,67 +55,13 @@ public sealed partial class EntityWhitelist
[DataField] [DataField]
public bool RequireAll; public bool RequireAll;
public void UpdateRegistrations() [Obsolete("Use WhitelistSystem")]
public bool IsValid(EntityUid uid, IEntityManager? man = null)
{ {
var sys = man?.System<EntityWhitelistSystem>() ??
IoCManager.Resolve<IEntitySystemManager>().GetEntitySystem<EntityWhitelistSystem>();
if (Components == null) return sys.IsValid(this, uid);
return;
var compFact = IoCManager.Resolve<IComponentFactory>();
_registrations = new List<ComponentRegistration>();
foreach (var name in Components)
{
var availability = compFact.GetComponentAvailability(name);
if (compFact.TryGetRegistration(name, out var registration)
&& availability == ComponentAvailability.Available)
{
_registrations.Add(registration);
}
else if (availability == ComponentAvailability.Unknown)
{
Logger.Warning($"Unknown component name {name} passed to EntityWhitelist!");
}
}
}
/// <summary>
/// Returns whether a given entity fits the whitelist.
/// </summary>
public bool IsValid(EntityUid uid, IEntityManager? entityManager = null)
{
if (Components != null && _registrations == null)
UpdateRegistrations();
IoCManager.Resolve(ref entityManager);
if (_registrations != null)
{
foreach (var reg in _registrations)
{
if (entityManager.HasComponent(uid, reg.Type))
{
if (!RequireAll)
return true;
}
else if (RequireAll)
return false;
}
}
if (Sizes != null && entityManager.TryGetComponent(uid, out ItemComponent? itemComp))
{
if (Sizes.Contains(itemComp.Size))
return true;
}
if (Tags != null && entityManager.TryGetComponent(uid, out TagComponent? tags))
{
var tagSystem = entityManager.System<TagSystem>();
return RequireAll ? tagSystem.HasAllTags(tags, Tags) : tagSystem.HasAnyTag(tags, Tags);
}
if (RequireAll)
return true;
return false;
} }
} }

View File

@@ -0,0 +1,84 @@
using System.Diagnostics.CodeAnalysis;
using Content.Shared.Item;
using Content.Shared.Tag;
namespace Content.Shared.Whitelist;
public sealed class EntityWhitelistSystem : EntitySystem
{
[Dependency] private readonly IComponentFactory _factory = default!;
[Dependency] private readonly TagSystem _tag = default!;
private EntityQuery<ItemComponent> _itemQuery;
public override void Initialize()
{
base.Initialize();
_itemQuery = GetEntityQuery<ItemComponent>();
}
/// <inheritdoc cref="IsValid(Content.Shared.Whitelist.EntityWhitelist,Robust.Shared.GameObjects.EntityUid)"/>
public bool IsValid(EntityWhitelist list, [NotNullWhen(true)] EntityUid? uid)
{
return uid != null && IsValid(list, uid.Value);
}
/// <summary>
/// Checks whether a given entity satisfies a whitelist.
/// </summary>
public bool IsValid(EntityWhitelist list, EntityUid uid)
{
if (list.Components != null)
EnsureRegistrations(list);
if (list.Registrations != null)
{
foreach (var reg in list.Registrations)
{
if (HasComp(uid, reg.Type))
{
if (!list.RequireAll)
return true;
}
else if (list.RequireAll)
return false;
}
}
if (list.Sizes != null && _itemQuery.TryComp(uid, out var itemComp))
{
if (list.Sizes.Contains(itemComp.Size))
return true;
}
if (list.Tags != null)
{
return list.RequireAll
? _tag.HasAllTags(uid, list.Tags)
: _tag.HasAnyTag(uid, list.Tags);
}
return list.RequireAll;
}
private void EnsureRegistrations(EntityWhitelist list)
{
if (list.Components == null)
return;
list.Registrations = new List<ComponentRegistration>();
foreach (var name in list.Components)
{
var availability = _factory.GetComponentAvailability(name);
if (_factory.TryGetRegistration(name, out var registration)
&& availability == ComponentAvailability.Available)
{
list.Registrations.Add(registration);
}
else if (availability == ComponentAvailability.Unknown)
{
Log.Warning($"Unknown component name {name} passed to EntityWhitelist!");
}
}
}
}