Add EntityWhitelistSystem (#27632)

* Add EntityWhitelistSystem

* Sandbox fix

* update test
This commit is contained in:
Leon Friedrich
2024-05-03 12:10:15 +12:00
committed by GitHub
parent 291ecf9643
commit f348e6aa30
4 changed files with 117 additions and 89 deletions

View File

@@ -64,7 +64,8 @@ namespace Content.IntegrationTests.Tests.Utility
var testMap = await pair.CreateTestMap();
var mapCoordinates = testMap.MapCoords;
var sEntities = server.ResolveDependency<IEntityManager>();
var sEntities = server.EntMan;
var sys = server.System<EntityWhitelistSystem>();
await server.WaitAssertion(() =>
{
@@ -80,22 +81,14 @@ namespace Content.IntegrationTests.Tests.Utility
Components = new[] { $"{ValidComponent}" },
Tags = new() { "WhitelistTestValidTag" }
};
whitelistInst.UpdateRegistrations();
Assert.That(whitelistInst, Is.Not.Null);
Assert.Multiple(() =>
{
Assert.That(whitelistInst.Components, Is.Not.Null);
Assert.That(whitelistInst.Tags, Is.Not.Null);
});
Assert.That(sys.IsValid(whitelistInst, validComponent), Is.True);
Assert.That(sys.IsValid(whitelistInst, WhitelistTestValidTag), Is.True);
Assert.Multiple(() =>
{
Assert.That(whitelistInst.IsValid(validComponent), Is.True);
Assert.That(whitelistInst.IsValid(WhitelistTestValidTag), Is.True);
Assert.That(whitelistInst.IsValid(invalidComponent), Is.False);
Assert.That(whitelistInst.IsValid(WhitelistTestInvalidTag), Is.False);
Assert.That(sys.IsValid(whitelistInst, invalidComponent), Is.False);
Assert.That(sys.IsValid(whitelistInst, WhitelistTestInvalidTag), Is.False);
});
// Test from serialized
@@ -111,11 +104,11 @@ namespace Content.IntegrationTests.Tests.Utility
Assert.Multiple(() =>
{
Assert.That(whitelistSer.IsValid(validComponent), Is.True);
Assert.That(whitelistSer.IsValid(WhitelistTestValidTag), Is.True);
Assert.That(sys.IsValid(whitelistSer, validComponent), Is.True);
Assert.That(sys.IsValid(whitelistSer, WhitelistTestValidTag), Is.True);
Assert.That(whitelistSer.IsValid(invalidComponent), Is.False);
Assert.That(whitelistSer.IsValid(WhitelistTestInvalidTag), Is.False);
Assert.That(sys.IsValid(whitelistSer, invalidComponent), Is.False);
Assert.That(sys.IsValid(whitelistSer, WhitelistTestInvalidTag), Is.False);
});
});
await pair.CleanReturnAsync();

View File

@@ -1,3 +1,4 @@
using System.Diagnostics;
using System.Linq;
using Robust.Shared.GameStates;
using Robust.Shared.Prototypes;
@@ -185,6 +186,7 @@ public sealed class TagSystem : EntitySystem
/// <summary>
/// Checks if a tag has been added to an entity.
/// </summary>
[Obsolete]
public bool HasTag(EntityUid entity, string id, EntityQuery<TagComponent> tagQuery)
{
return tagQuery.TryGetComponent(entity, out var component) &&
@@ -243,7 +245,7 @@ public sealed class TagSystem : EntitySystem
/// </exception>
public bool HasAllTags(EntityUid entity, List<ProtoId<TagPrototype>> ids)
{
return TryComp<TagComponent>(entity, out var component) &&
return _tagQuery.TryComp(entity, out var component) &&
HasAllTags(component, ids);
}
@@ -521,16 +523,18 @@ public sealed class TagSystem : EntitySystem
/// </exception>
public bool HasAllTags(TagComponent component, List<ProtoId<TagPrototype>> ids)
{
var stringIds = new List<string>();
foreach (var tag in ids)
foreach (var id in ids)
{
stringIds.Add(tag.Id);
AssertValidTag(id);
if (!component.Tags.Contains(id))
return false;
}
return HasAllTags(component, stringIds);
return true;
}
/// <summary>
/// Checks if any of the given tags have been added.
/// </summary>
@@ -552,7 +556,6 @@ public sealed class TagSystem : EntitySystem
return false;
}
/// <summary>
/// Checks if any of the given tags have been added.
/// </summary>
@@ -619,13 +622,15 @@ public sealed class TagSystem : EntitySystem
/// </exception>
public bool HasAnyTag(TagComponent comp, List<ProtoId<TagPrototype>> ids)
{
var stringIds = new List<string>();
foreach (var tag in ids)
foreach (var id in ids)
{
stringIds.Add(tag.Id);
AssertValidTag(id);
if (comp.Tags.Contains(id))
return true;
}
return HasAnyTag(comp, stringIds);
return false;
}
/// <summary>

View File

@@ -38,8 +38,8 @@ public sealed partial class EntityWhitelist
[DataField]
public List<ProtoId<ItemSizePrototype>>? Sizes;
[NonSerialized]
private List<ComponentRegistration>? _registrations;
[NonSerialized, Access(typeof(EntityWhitelistSystem))]
public List<ComponentRegistration>? Registrations;
/// <summary>
/// Tags that are allowed in the whitelist.
@@ -55,67 +55,13 @@ public sealed partial class EntityWhitelist
[DataField]
public bool RequireAll;
public void UpdateRegistrations()
[Obsolete("Use WhitelistSystem")]
public bool IsValid(EntityUid uid, IEntityManager? man = null)
{
var sys = man?.System<EntityWhitelistSystem>() ??
IoCManager.Resolve<IEntitySystemManager>().GetEntitySystem<EntityWhitelistSystem>();
if (Components == null)
return;
return sys.IsValid(this, uid);
var compFact = IoCManager.Resolve<IComponentFactory>();
_registrations = new List<ComponentRegistration>();
foreach (var name in Components)
{
var availability = compFact.GetComponentAvailability(name);
if (compFact.TryGetRegistration(name, out var registration)
&& availability == ComponentAvailability.Available)
{
_registrations.Add(registration);
}
else if (availability == ComponentAvailability.Unknown)
{
Logger.Warning($"Unknown component name {name} passed to EntityWhitelist!");
}
}
}
/// <summary>
/// Returns whether a given entity fits the whitelist.
/// </summary>
public bool IsValid(EntityUid uid, IEntityManager? entityManager = null)
{
if (Components != null && _registrations == null)
UpdateRegistrations();
IoCManager.Resolve(ref entityManager);
if (_registrations != null)
{
foreach (var reg in _registrations)
{
if (entityManager.HasComponent(uid, reg.Type))
{
if (!RequireAll)
return true;
}
else if (RequireAll)
return false;
}
}
if (Sizes != null && entityManager.TryGetComponent(uid, out ItemComponent? itemComp))
{
if (Sizes.Contains(itemComp.Size))
return true;
}
if (Tags != null && entityManager.TryGetComponent(uid, out TagComponent? tags))
{
var tagSystem = entityManager.System<TagSystem>();
return RequireAll ? tagSystem.HasAllTags(tags, Tags) : tagSystem.HasAnyTag(tags, Tags);
}
if (RequireAll)
return true;
return false;
}
}

View File

@@ -0,0 +1,84 @@
using System.Diagnostics.CodeAnalysis;
using Content.Shared.Item;
using Content.Shared.Tag;
namespace Content.Shared.Whitelist;
public sealed class EntityWhitelistSystem : EntitySystem
{
[Dependency] private readonly IComponentFactory _factory = default!;
[Dependency] private readonly TagSystem _tag = default!;
private EntityQuery<ItemComponent> _itemQuery;
public override void Initialize()
{
base.Initialize();
_itemQuery = GetEntityQuery<ItemComponent>();
}
/// <inheritdoc cref="IsValid(Content.Shared.Whitelist.EntityWhitelist,Robust.Shared.GameObjects.EntityUid)"/>
public bool IsValid(EntityWhitelist list, [NotNullWhen(true)] EntityUid? uid)
{
return uid != null && IsValid(list, uid.Value);
}
/// <summary>
/// Checks whether a given entity satisfies a whitelist.
/// </summary>
public bool IsValid(EntityWhitelist list, EntityUid uid)
{
if (list.Components != null)
EnsureRegistrations(list);
if (list.Registrations != null)
{
foreach (var reg in list.Registrations)
{
if (HasComp(uid, reg.Type))
{
if (!list.RequireAll)
return true;
}
else if (list.RequireAll)
return false;
}
}
if (list.Sizes != null && _itemQuery.TryComp(uid, out var itemComp))
{
if (list.Sizes.Contains(itemComp.Size))
return true;
}
if (list.Tags != null)
{
return list.RequireAll
? _tag.HasAllTags(uid, list.Tags)
: _tag.HasAnyTag(uid, list.Tags);
}
return list.RequireAll;
}
private void EnsureRegistrations(EntityWhitelist list)
{
if (list.Components == null)
return;
list.Registrations = new List<ComponentRegistration>();
foreach (var name in list.Components)
{
var availability = _factory.GetComponentAvailability(name);
if (_factory.TryGetRegistration(name, out var registration)
&& availability == ComponentAvailability.Available)
{
list.Registrations.Add(registration);
}
else if (availability == ComponentAvailability.Unknown)
{
Log.Warning($"Unknown component name {name} passed to EntityWhitelist!");
}
}
}
}