add support for per-id access on AccessReaderComponent (#13659)

* add support for per-id access on AccessReaderComponent

* comments!!!

* oh yeah we predicting baby

* foobar

* sloth review

* weh
This commit is contained in:
Nemanja
2023-02-28 11:03:55 -05:00
committed by GitHub
parent 38e61e1709
commit 13d71f14e2
12 changed files with 234 additions and 83 deletions

View File

@@ -17,59 +17,60 @@ namespace Content.IntegrationTests.Tests.Access
await using var pairTracker = await PoolManager.GetServerClient(new PoolSettings{NoClient = true}); await using var pairTracker = await PoolManager.GetServerClient(new PoolSettings{NoClient = true});
var server = pairTracker.Pair.Server; var server = pairTracker.Pair.Server;
await server.WaitAssertion(() => await server.WaitAssertion(() =>
{ {
var system = EntitySystem.Get<AccessReaderSystem>(); var system = EntitySystem.Get<AccessReaderSystem>();
// test empty // test empty
var reader = new AccessReaderComponent(); var reader = new AccessReaderComponent();
Assert.That(system.IsAllowed(new[] { "Foo" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "Foo" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "Bar" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "Bar" }, reader), Is.True);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.True);
// test deny // test deny
reader = new AccessReaderComponent(); reader = new AccessReaderComponent();
reader.DenyTags.Add("A"); reader.DenyTags.Add("A");
Assert.That(system.IsAllowed(new[] { "Foo" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "Foo" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "A" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "A" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "A", "Foo" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "A", "Foo" }, reader), Is.False);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.True);
// test one list // test one list
reader = new AccessReaderComponent(); reader = new AccessReaderComponent();
reader.AccessLists.Add(new HashSet<string> { "A" }); reader.AccessLists.Add(new HashSet<string> { "A" });
Assert.That(system.IsAllowed(new[] { "A" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "B" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "B" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "A", "B" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A", "B" }, reader), Is.True);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.False);
// test one list - two items // test one list - two items
reader = new AccessReaderComponent(); reader = new AccessReaderComponent();
reader.AccessLists.Add(new HashSet<string> { "A", "B" }); reader.AccessLists.Add(new HashSet<string> { "A", "B" });
Assert.That(system.IsAllowed(new[] { "A" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "A" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "B" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "B" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "A", "B" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A", "B" }, reader), Is.True);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.False);
// test two list // test two list
reader = new AccessReaderComponent(); reader = new AccessReaderComponent();
reader.AccessLists.Add(new HashSet<string> { "A" }); reader.AccessLists.Add(new HashSet<string> { "A" });
reader.AccessLists.Add(new HashSet<string> { "B", "C" }); reader.AccessLists.Add(new HashSet<string> { "B", "C" });
Assert.That(system.IsAllowed(new[] { "A" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "B" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "B" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "A", "B" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A", "B" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "C", "B" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "C", "B" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "C", "B", "A" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "C", "B", "A" }, reader), Is.True);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.False);
// test deny list // test deny list
reader = new AccessReaderComponent(); reader = new AccessReaderComponent();
reader.AccessLists.Add(new HashSet<string> { "A" }); reader.AccessLists.Add(new HashSet<string> { "A" });
reader.DenyTags.Add("B"); reader.DenyTags.Add("B");
Assert.That(system.IsAllowed(new[] { "A" }, reader), Is.True); Assert.That(system.AreAccessTagsAllowed(new[] { "A" }, reader), Is.True);
Assert.That(system.IsAllowed(new[] { "B" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "B" }, reader), Is.False);
Assert.That(system.IsAllowed(new[] { "A", "B" }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new[] { "A", "B" }, reader), Is.False);
Assert.That(system.IsAllowed(new string[] { }, reader), Is.False); Assert.That(system.AreAccessTagsAllowed(new string[] { }, reader), Is.False);
}); });
await pairTracker.CleanReturnAsync(); await pairTracker.CleanReturnAsync();
} }

View File

@@ -1,6 +1,6 @@
using System.Linq; using System.Linq;
using Content.Server.Station.Systems; using Content.Server.Station.Systems;
using Content.Server.StationRecords; using Content.Server.StationRecords.Systems;
using Content.Shared.Access.Components; using Content.Shared.Access.Components;
using Content.Shared.Access.Systems; using Content.Shared.Access.Systems;
using Content.Shared.Administration.Logs; using Content.Shared.Administration.Logs;

View File

@@ -1,22 +1,17 @@
using System.Linq; using System.Linq;
using Content.Server.Administration; using Content.Server.Administration;
using Content.Server.EUI; using Content.Server.EUI;
using Content.Server.GameTicking;
using Content.Server.Station.Systems; using Content.Server.Station.Systems;
using Content.Server.StationRecords; using Content.Server.StationRecords;
using Content.Server.StationRecords.Systems;
using Content.Shared.Administration; using Content.Shared.Administration;
using Content.Shared.CCVar; using Content.Shared.CCVar;
using Content.Shared.CrewManifest; using Content.Shared.CrewManifest;
using Content.Shared.GameTicking; using Content.Shared.GameTicking;
using Content.Shared.Roles;
using Content.Shared.StationRecords; using Content.Shared.StationRecords;
using Robust.Server.GameObjects;
using Robust.Server.Player; using Robust.Server.Player;
using Robust.Shared.Configuration; using Robust.Shared.Configuration;
using Robust.Shared.Console; using Robust.Shared.Console;
using Robust.Shared.Player;
using Robust.Shared.Players;
using Robust.Shared.Prototypes;
namespace Content.Server.CrewManifest; namespace Content.Server.CrewManifest;

View File

@@ -1,10 +1,9 @@
using Content.Server.Access.Systems; using Content.Server.Access.Systems;
using Content.Server.Administration; using Content.Server.Administration;
using Content.Server.Administration.Systems; using Content.Server.Administration.Systems;
using Content.Server.Cloning;
using Content.Server.Mind.Components; using Content.Server.Mind.Components;
using Content.Server.PDA; using Content.Server.PDA;
using Content.Server.StationRecords; using Content.Server.StationRecords.Systems;
using Content.Shared.Access.Components; using Content.Shared.Access.Components;
using Content.Shared.Administration; using Content.Shared.Administration;
using Content.Shared.PDA; using Content.Shared.PDA;

View File

@@ -1,13 +0,0 @@
using Content.Shared.StationRecords;
namespace Content.Server.StationRecords;
[RegisterComponent]
public sealed class StationRecordKeyStorageComponent : Component
{
/// <summary>
/// The key stored in this component.
/// </summary>
[ViewVariables]
public StationRecordKey? Key;
}

View File

@@ -1,3 +1,5 @@
using Content.Server.StationRecords.Systems;
namespace Content.Server.StationRecords; namespace Content.Server.StationRecords;
[Access(typeof(StationRecordsSystem))] [Access(typeof(StationRecordsSystem))]

View File

@@ -1,9 +1,7 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using Content.Server.Access.Systems; using System.Linq;
using Content.Server.GameTicking; using Content.Server.GameTicking;
using Content.Server.Station.Systems; using Content.Server.Station.Systems;
using Content.Server.StationRecords;
using Content.Server.StationRecords.Systems;
using Content.Shared.Access.Components; using Content.Shared.Access.Components;
using Content.Shared.Inventory; using Content.Shared.Inventory;
using Content.Shared.PDA; using Content.Shared.PDA;
@@ -13,6 +11,8 @@ using Content.Shared.StationRecords;
using Robust.Shared.Enums; using Robust.Shared.Enums;
using Robust.Shared.Prototypes; using Robust.Shared.Prototypes;
namespace Content.Server.StationRecords.Systems;
/// <summary> /// <summary>
/// Station records. /// Station records.
/// ///

View File

@@ -1,22 +1,52 @@
namespace Content.Shared.Access.Components using Content.Shared.Access.Systems;
using Content.Shared.StationRecords;
using Robust.Shared.GameStates;
using Robust.Shared.Serialization;
using Robust.Shared.Serialization.TypeSerializers.Implementations.Custom.Prototype.Set;
namespace Content.Shared.Access.Components;
/// <summary>
/// Stores access levels necessary to "use" an entity
/// and allows checking if something or somebody is authorized with these access levels.
/// </summary>
[RegisterComponent, NetworkedComponent]
public sealed class AccessReaderComponent : Component
{ {
/// <summary> /// <summary>
/// Stores access levels necessary to "use" an entity /// The set of tags that will automatically deny an allowed check, if any of them are present.
/// and allows checking if something or somebody is authorized with these access levels.
/// </summary> /// </summary>
[RegisterComponent] [DataField("denyTags", customTypeSerializer: typeof(PrototypeIdHashSetSerializer<AccessLevelPrototype>))]
public sealed class AccessReaderComponent : Component public HashSet<string> DenyTags = new();
{
/// <summary>
/// The set of tags that will automatically deny an allowed check, if any of them are present.
/// </summary>
public HashSet<string> DenyTags = new();
/// <summary> /// <summary>
/// List of access lists to check allowed against. For an access check to pass /// List of access lists to check allowed against. For an access check to pass
/// there has to be an access list that is a subset of the access in the checking list. /// there has to be an access list that is a subset of the access in the checking list.
/// </summary> /// </summary>
[DataField("access")] [DataField("access")]
public List<HashSet<string>> AccessLists = new(); public List<HashSet<string>> AccessLists = new();
/// <summary>
/// A list of valid stationrecordkeys
/// </summary>
[DataField("accessKeys")]
public HashSet<StationRecordKey> AccessKeys = new();
}
[Serializable, NetSerializable]
public sealed class AccessReaderComponentState : ComponentState
{
public HashSet<string> DenyTags;
public List<HashSet<string>> AccessLists;
public HashSet<StationRecordKey> AccessKeys;
public AccessReaderComponentState(HashSet<string> denyTags, List<HashSet<string>> accessLists, HashSet<StationRecordKey> accessKeys)
{
DenyTags = denyTags;
AccessLists = accessLists;
AccessKeys = accessKeys;
} }
} }

View File

@@ -8,6 +8,8 @@ using Content.Shared.Access.Components;
using Robust.Shared.Prototypes; using Robust.Shared.Prototypes;
using Content.Shared.Hands.EntitySystems; using Content.Shared.Hands.EntitySystems;
using Content.Shared.MachineLinking.Events; using Content.Shared.MachineLinking.Events;
using Content.Shared.StationRecords;
using Robust.Shared.GameStates;
namespace Content.Shared.Access.Systems namespace Content.Shared.Access.Systems
{ {
@@ -23,6 +25,24 @@ namespace Content.Shared.Access.Systems
SubscribeLocalEvent<AccessReaderComponent, ComponentInit>(OnInit); SubscribeLocalEvent<AccessReaderComponent, ComponentInit>(OnInit);
SubscribeLocalEvent<AccessReaderComponent, GotEmaggedEvent>(OnEmagged); SubscribeLocalEvent<AccessReaderComponent, GotEmaggedEvent>(OnEmagged);
SubscribeLocalEvent<AccessReaderComponent, LinkAttemptEvent>(OnLinkAttempt); SubscribeLocalEvent<AccessReaderComponent, LinkAttemptEvent>(OnLinkAttempt);
SubscribeLocalEvent<AccessReaderComponent, ComponentGetState>(OnGetState);
SubscribeLocalEvent<AccessReaderComponent, ComponentHandleState>(OnHandleState);
}
private void OnGetState(EntityUid uid, AccessReaderComponent component, ref ComponentGetState args)
{
args.State = new AccessReaderComponentState(component.DenyTags, component.AccessLists,
component.AccessKeys);
}
private void OnHandleState(EntityUid uid, AccessReaderComponent component, ref ComponentHandleState args)
{
if (args.Current is not AccessReaderComponentState state)
return;
component.AccessKeys = new (state.AccessKeys);
component.AccessLists = new (state.AccessLists);
component.DenyTags = new (state.DenyTags);
} }
private void OnLinkAttempt(EntityUid uid, AccessReaderComponent component, LinkAttemptEvent args) private void OnLinkAttempt(EntityUid uid, AccessReaderComponent component, LinkAttemptEvent args)
@@ -62,8 +82,7 @@ namespace Content.Shared.Access.Systems
{ {
if (!Resolve(target, ref reader, false)) if (!Resolve(target, ref reader, false))
return true; return true;
var tags = FindAccessTags(source); return IsAllowed(source, reader);
return IsAllowed(tags, reader);
} }
/// <summary> /// <summary>
@@ -74,8 +93,15 @@ namespace Content.Shared.Access.Systems
/// <param name="reader">A reader from a different entity</param> /// <param name="reader">A reader from a different entity</param>
public bool IsAllowed(EntityUid entity, AccessReaderComponent reader) public bool IsAllowed(EntityUid entity, AccessReaderComponent reader)
{ {
var tags = FindAccessTags(entity); var allEnts = FindPotentialAccessItems(entity);
return IsAllowed(tags, reader);
if (AreAccessTagsAllowed(FindAccessTags(entity, allEnts), reader))
return true;
if (AreStationRecordKeysAllowed(FindStationRecordKeys(entity, allEnts), reader))
return true;
return false;
} }
/// <summary> /// <summary>
@@ -83,7 +109,7 @@ namespace Content.Shared.Access.Systems
/// </summary> /// </summary>
/// <param name="accessTags">A list of access tags</param> /// <param name="accessTags">A list of access tags</param>
/// <param name="reader">An access reader to check against</param> /// <param name="reader">An access reader to check against</param>
public bool IsAllowed(ICollection<string> accessTags, AccessReaderComponent reader) public bool AreAccessTagsAllowed(ICollection<string> accessTags, AccessReaderComponent reader)
{ {
if (HasComp<EmaggedComponent>(reader.Owner)) if (HasComp<EmaggedComponent>(reader.Owner))
{ {
@@ -105,17 +131,18 @@ namespace Content.Shared.Access.Systems
} }
/// <summary> /// <summary>
/// Finds the access tags on the given entity /// Compares the given stationrecordkeys with the accessreader to see if it is allowed.
/// </summary> /// </summary>
/// <param name="uid">The entity that is being searched.</param> public bool AreStationRecordKeysAllowed(ICollection<StationRecordKey> keys, AccessReaderComponent reader)
public ICollection<string> FindAccessTags(EntityUid uid)
{ {
HashSet<string>? tags = null; return keys.Any() && reader.AccessKeys.Any(keys.Contains);
var owned = false; }
// check entity itself
FindAccessTagsItem(uid, ref tags, ref owned);
/// <summary>
/// Finds all the items that could potentially give access to a given entity
/// </summary>
public HashSet<EntityUid> FindPotentialAccessItems(EntityUid uid)
{
FindAccessItemsInventory(uid, out var items); FindAccessItemsInventory(uid, out var items);
var ev = new GetAdditionalAccessEvent var ev = new GetAdditionalAccessEvent
@@ -123,7 +150,23 @@ namespace Content.Shared.Access.Systems
Entities = items Entities = items
}; };
RaiseLocalEvent(uid, ref ev); RaiseLocalEvent(uid, ref ev);
foreach (var ent in ev.Entities) items.Add(uid);
return items;
}
/// <summary>
/// Finds the access tags on the given entity
/// </summary>
/// <param name="uid">The entity that is being searched.</param>
/// <param name="items">All of the items to search for access. If none are passed in, <see cref="FindPotentialAccessItems"/> will be used.</param>
public ICollection<string> FindAccessTags(EntityUid uid, HashSet<EntityUid>? items = null)
{
HashSet<string>? tags = null;
var owned = false;
items ??= FindPotentialAccessItems(uid);
foreach (var ent in items)
{ {
FindAccessTagsItem(ent, ref tags, ref owned); FindAccessTagsItem(ent, ref tags, ref owned);
} }
@@ -131,6 +174,26 @@ namespace Content.Shared.Access.Systems
return (ICollection<string>?) tags ?? Array.Empty<string>(); return (ICollection<string>?) tags ?? Array.Empty<string>();
} }
/// <summary>
/// Finds the access tags on the given entity
/// </summary>
/// <param name="uid">The entity that is being searched.</param>
/// <param name="items">All of the items to search for access. If none are passed in, <see cref="FindPotentialAccessItems"/> will be used.</param>
public ICollection<StationRecordKey> FindStationRecordKeys(EntityUid uid, HashSet<EntityUid>? items = null)
{
HashSet<StationRecordKey> keys = new();
items ??= FindPotentialAccessItems(uid);
foreach (var ent in items)
{
if (FindStationRecordKeyItem(ent, out var key))
keys.Add(key.Value);
}
return keys;
}
/// <summary> /// <summary>
/// Try to find <see cref="AccessComponent"/> on this item /// Try to find <see cref="AccessComponent"/> on this item
/// or inside this item (if it's pda) /// or inside this item (if it's pda)
@@ -203,5 +266,31 @@ namespace Content.Shared.Access.Systems
tags = null; tags = null;
return false; return false;
} }
/// <summary>
/// Try to find <see cref="StationRecordKeyStorageComponent"/> on this item
/// or inside this item (if it's pda)
/// </summary>
private bool FindStationRecordKeyItem(EntityUid uid, [NotNullWhen(true)] out StationRecordKey? key)
{
if (EntityManager.TryGetComponent(uid, out StationRecordKeyStorageComponent? storage) && storage.Key != null)
{
key = storage.Key;
return true;
}
if (TryComp<PDAComponent>(uid, out var pda) &&
pda.ContainedID?.Owner is {Valid: true} id)
{
if (TryComp<StationRecordKeyStorageComponent>(id, out var pdastorage) && pdastorage.Key != null)
{
key = pdastorage.Key;
return true;
}
}
key = null;
return false;
}
} }
} }

View File

@@ -130,7 +130,8 @@ namespace Content.Shared.Emag.Systems
if (component.Charges <= 0) if (component.Charges <= 0)
{ {
_popupSystem.PopupEntity(Loc.GetString("emag-no-charges"), user, user); if (_net.IsServer)
_popupSystem.PopupEntity(Loc.GetString("emag-no-charges"), user, user);
return false; return false;
} }

View File

@@ -0,0 +1,25 @@
using Robust.Shared.GameStates;
using Robust.Shared.Serialization;
namespace Content.Shared.StationRecords;
[RegisterComponent, NetworkedComponent]
public sealed class StationRecordKeyStorageComponent : Component
{
/// <summary>
/// The key stored in this component.
/// </summary>
[ViewVariables]
public StationRecordKey? Key;
}
[Serializable, NetSerializable]
public sealed class StationRecordKeyStorageComponentState : ComponentState
{
public StationRecordKey? Key;
public StationRecordKeyStorageComponentState(StationRecordKey? key)
{
Key = key;
}
}

View File

@@ -1,9 +1,29 @@
using Content.Shared.StationRecords; using Robust.Shared.GameStates;
namespace Content.Server.StationRecords.Systems; namespace Content.Shared.StationRecords;
public sealed class StationRecordKeyStorageSystem : EntitySystem public sealed class StationRecordKeyStorageSystem : EntitySystem
{ {
public override void Initialize()
{
base.Initialize();
SubscribeLocalEvent<StationRecordKeyStorageComponent, ComponentGetState>(OnGetState);
SubscribeLocalEvent<StationRecordKeyStorageComponent, ComponentHandleState>(OnHandleState);
}
private void OnGetState(EntityUid uid, StationRecordKeyStorageComponent component, ref ComponentGetState args)
{
args.State = new StationRecordKeyStorageComponentState(component.Key);
}
private void OnHandleState(EntityUid uid, StationRecordKeyStorageComponent component, ref ComponentHandleState args)
{
if (args.Current is not StationRecordKeyStorageComponentState state)
return;
component.Key = state.Key;
}
/// <summary> /// <summary>
/// Assigns a station record key to an entity. /// Assigns a station record key to an entity.
/// </summary> /// </summary>
@@ -18,6 +38,7 @@ public sealed class StationRecordKeyStorageSystem : EntitySystem
} }
keyStorage.Key = key; keyStorage.Key = key;
Dirty(keyStorage);
} }
/// <summary> /// <summary>
@@ -35,6 +56,7 @@ public sealed class StationRecordKeyStorageSystem : EntitySystem
var key = keyStorage.Key; var key = keyStorage.Key;
keyStorage.Key = null; keyStorage.Key = null;
Dirty(keyStorage);
return key; return key;
} }