using System; using System.Linq; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Npgsql; namespace Content.Server.Database { public sealed class PostgresServerDbContext : ServerDbContext { public PostgresServerDbContext(DbContextOptions options) : base(options) { } protected override void OnConfiguring(DbContextOptionsBuilder options) { ((IDbContextOptionsBuilderInfrastructure) options).AddOrUpdateExtension(new SnakeCaseExtension()); options.ConfigureWarnings(x => { x.Ignore(CoreEventId.ManyServiceProvidersCreatedWarning); #if DEBUG // for tests x.Ignore(CoreEventId.SensitiveDataLoggingEnabledWarning); #endif }); #if DEBUG options.EnableSensitiveDataLogging(); #endif } protected override void OnModelCreating(ModelBuilder modelBuilder) { base.OnModelCreating(modelBuilder); // ReSharper disable StringLiteralTypo // Enforce that an address cannot be IPv6-mapped IPv4. // So that IPv4 addresses are consistent between separate-socket and dual-stack socket modes. modelBuilder.Entity().ToTable(t => t.HasCheckConstraint("AddressNotIPv6MappedIPv4", "NOT inet '::ffff:0.0.0.0/96' >>= address")); modelBuilder.Entity().ToTable( t => t.HasCheckConstraint("AddressNotIPv6MappedIPv4", "NOT inet '::ffff:0.0.0.0/96' >>= address")); modelBuilder.Entity().ToTable(t => t.HasCheckConstraint("LastSeenAddressNotIPv6MappedIPv4", "NOT inet '::ffff:0.0.0.0/96' >>= last_seen_address")); modelBuilder.Entity().ToTable(t => t.HasCheckConstraint("AddressNotIPv6MappedIPv4", "NOT inet '::ffff:0.0.0.0/96' >>= address")); // ReSharper restore StringLiteralTypo modelBuilder.Entity() .HasIndex(l => l.Message) .HasMethod("GIN") .IsTsVectorExpressionIndex("english"); foreach(var entity in modelBuilder.Model.GetEntityTypes()) { foreach(var property in entity.GetProperties()) { if (property.FieldInfo?.FieldType == typeof(DateTime) || property.FieldInfo?.FieldType == typeof(DateTime?)) property.SetColumnType("timestamp with time zone"); } } } public override IQueryable SearchLogs(IQueryable query, string searchText) { return query.Where(log => EF.Functions.ToTsVector("english", log.Message).Matches(searchText)); } public override int CountAdminLogs() { using var command = new NpgsqlCommand("SELECT reltuples FROM pg_class WHERE relname = 'admin_log';", (NpgsqlConnection?) Database.GetDbConnection()); Database.GetDbConnection().Open(); var count = Convert.ToInt32((float) (command.ExecuteScalar() ?? 0)); Database.GetDbConnection().Close(); return count; } } }