From d4fa8f09546270d346d6085572c0511b56217a49 Mon Sep 17 00:00:00 2001 From: audinowho <2676737+audinowho@users.noreply.github.com> Date: Tue, 27 May 2025 12:37:54 -0800 Subject: [PATCH] Support `--include-threads` in the `export` command (#1343) --- .../Commands/Base/ExportCommandBase.cs | 60 +++++++++++++++++-- .../Commands/ExportAllCommand.cs | 51 ---------------- .../Commands/ExportChannelsCommand.cs | 43 ++++++++++++- .../Commands/ExportGuildCommand.cs | 41 ------------- .../Discord/DiscordClient.cs | 40 +++++++++++-- 5 files changed, 131 insertions(+), 104 deletions(-) diff --git a/DiscordChatExporter.Cli/Commands/Base/ExportCommandBase.cs b/DiscordChatExporter.Cli/Commands/Base/ExportCommandBase.cs index d77cd74d..e386c7f4 100644 --- a/DiscordChatExporter.Cli/Commands/Base/ExportCommandBase.cs +++ b/DiscordChatExporter.Cli/Commands/Base/ExportCommandBase.cs @@ -8,6 +8,7 @@ using CliFx.Attributes; using CliFx.Exceptions; using CliFx.Infrastructure; using DiscordChatExporter.Cli.Commands.Converters; +using DiscordChatExporter.Cli.Commands.Shared; using DiscordChatExporter.Cli.Utils.Extensions; using DiscordChatExporter.Core.Discord; using DiscordChatExporter.Core.Discord.Data; @@ -64,6 +65,13 @@ public abstract class ExportCommandBase : DiscordCommandBase )] public PartitionLimit PartitionLimit { get; init; } = PartitionLimit.Null; + [CommandOption( + "include-threads", + Description = "Which types of threads should be included.", + Converter = typeof(ThreadInclusionModeBindingConverter) + )] + public ThreadInclusionMode ThreadInclusionMode { get; init; } = ThreadInclusionMode.None; + [CommandOption( "filter", Description = "Only include messages that satisfy this filter. " @@ -141,6 +149,47 @@ public abstract class ExportCommandBase : DiscordCommandBase protected async ValueTask ExportAsync(IConsole console, IReadOnlyList channels) { + var cancellationToken = console.RegisterCancellationHandler(); + + var unwrappedChannels = new List(); + unwrappedChannels.AddRange(channels); + // Threads + if (ThreadInclusionMode != ThreadInclusionMode.None) + { + await console.Output.WriteLineAsync("Fetching threads..."); + + var fetchedThreadsCount = 0; + await console + .CreateStatusTicker() + .StartAsync( + "...", + async ctx => + { + await foreach ( + var thread in Discord.GetChannelThreadsAsync( + unwrappedChannels, + ThreadInclusionMode == ThreadInclusionMode.All, + Before, + After, + cancellationToken + ) + ) + { + unwrappedChannels.Add(thread); + + ctx.Status(Markup.Escape($"Fetched '{thread.GetHierarchicalName()}'.")); + + fetchedThreadsCount++; + } + } + ); + + // Remove unneeded forums, as they cannot be crawled directly. + unwrappedChannels.RemoveAll(channel => channel.Kind == ChannelKind.GuildForum); + + await console.Output.WriteLineAsync($"Fetched {fetchedThreadsCount} thread(s)."); + } + // Asset reuse can only be enabled if the download assets option is set // https://github.com/Tyrrrz/DiscordChatExporter/issues/425 if (ShouldReuseAssets && !ShouldDownloadAssets) @@ -160,7 +209,7 @@ public abstract class ExportCommandBase : DiscordCommandBase // https://github.com/Tyrrrz/DiscordChatExporter/issues/917 var isValidOutputPath = // Anything is valid when exporting a single channel - channels.Count <= 1 + unwrappedChannels.Count <= 1 // When using template tokens, assume the user knows what they're doing || OutputPath.Contains('%') // Otherwise, require an existing directory or an unambiguous directory path @@ -177,11 +226,10 @@ public abstract class ExportCommandBase : DiscordCommandBase } // Export - var cancellationToken = console.RegisterCancellationHandler(); var errorsByChannel = new ConcurrentDictionary(); var warningsByChannel = new ConcurrentDictionary(); - await console.Output.WriteLineAsync($"Exporting {channels.Count} channel(s)..."); + await console.Output.WriteLineAsync($"Exporting {unwrappedChannels.Count} channel(s)..."); await console .CreateProgressTicker() .HideCompleted( @@ -193,7 +241,7 @@ public abstract class ExportCommandBase : DiscordCommandBase .StartAsync(async ctx => { await Parallel.ForEachAsync( - channels, + unwrappedChannels, new ParallelOptions { MaxDegreeOfParallelism = Math.Max(1, ParallelLimit), @@ -253,7 +301,7 @@ public abstract class ExportCommandBase : DiscordCommandBase using (console.WithForegroundColor(ConsoleColor.White)) { await console.Output.WriteLineAsync( - $"Successfully exported {channels.Count - errorsByChannel.Count} channel(s)." + $"Successfully exported {unwrappedChannels.Count - errorsByChannel.Count} channel(s)." ); } @@ -301,7 +349,7 @@ public abstract class ExportCommandBase : DiscordCommandBase // Fail the command only if ALL channels failed to export. // If only some channels failed to export, it's okay. - if (errorsByChannel.Count >= channels.Count) + if (errorsByChannel.Count >= unwrappedChannels.Count) throw new CommandException("Export failed."); } diff --git a/DiscordChatExporter.Cli/Commands/ExportAllCommand.cs b/DiscordChatExporter.Cli/Commands/ExportAllCommand.cs index 1ec5e289..293fea7e 100644 --- a/DiscordChatExporter.Cli/Commands/ExportAllCommand.cs +++ b/DiscordChatExporter.Cli/Commands/ExportAllCommand.cs @@ -27,13 +27,6 @@ public class ExportAllCommand : ExportCommandBase [CommandOption("include-vc", Description = "Include voice channels.")] public bool IncludeVoiceChannels { get; init; } = true; - [CommandOption( - "include-threads", - Description = "Which types of threads should be included.", - Converter = typeof(ThreadInclusionModeBindingConverter) - )] - public ThreadInclusionMode ThreadInclusionMode { get; init; } = ThreadInclusionMode.None; - [CommandOption( "data-package", Description = "Path to the personal data package (ZIP file) requested from Discord. " @@ -90,46 +83,6 @@ public class ExportAllCommand : ExportCommandBase ); await console.Output.WriteLineAsync($"Fetched {fetchedChannelsCount} channel(s)."); - - // Threads - if (ThreadInclusionMode != ThreadInclusionMode.None) - { - await console.Output.WriteLineAsync( - $"Fetching threads for server '{guild.Name}'..." - ); - - var fetchedThreadsCount = 0; - await console - .CreateStatusTicker() - .StartAsync( - "...", - async ctx => - { - await foreach ( - var thread in Discord.GetGuildThreadsAsync( - guild.Id, - ThreadInclusionMode == ThreadInclusionMode.All, - Before, - After, - cancellationToken - ) - ) - { - channels.Add(thread); - - ctx.Status( - Markup.Escape($"Fetched '{thread.GetHierarchicalName()}'.") - ); - - fetchedThreadsCount++; - } - } - ); - - await console.Output.WriteLineAsync( - $"Fetched {fetchedThreadsCount} thread(s)." - ); - } } } // Pull from the data package @@ -199,10 +152,6 @@ public class ExportAllCommand : ExportCommandBase channels.RemoveAll(c => c.IsGuild); if (!IncludeVoiceChannels) channels.RemoveAll(c => c.IsVoice); - if (ThreadInclusionMode == ThreadInclusionMode.None) - channels.RemoveAll(c => c.IsThread); - if (ThreadInclusionMode != ThreadInclusionMode.All) - channels.RemoveAll(c => c is { IsThread: true, IsArchived: true }); await ExportAsync(console, channels); } diff --git a/DiscordChatExporter.Cli/Commands/ExportChannelsCommand.cs b/DiscordChatExporter.Cli/Commands/ExportChannelsCommand.cs index 79566c06..982e587e 100644 --- a/DiscordChatExporter.Cli/Commands/ExportChannelsCommand.cs +++ b/DiscordChatExporter.Cli/Commands/ExportChannelsCommand.cs @@ -1,9 +1,16 @@ using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using CliFx.Attributes; using CliFx.Infrastructure; using DiscordChatExporter.Cli.Commands.Base; +using DiscordChatExporter.Cli.Commands.Converters; +using DiscordChatExporter.Cli.Commands.Shared; +using DiscordChatExporter.Cli.Utils.Extensions; using DiscordChatExporter.Core.Discord; +using DiscordChatExporter.Core.Discord.Data; +using DiscordChatExporter.Core.Utils.Extensions; +using Spectre.Console; namespace DiscordChatExporter.Cli.Commands; @@ -22,6 +29,40 @@ public class ExportChannelsCommand : ExportCommandBase public override async ValueTask ExecuteAsync(IConsole console) { await base.ExecuteAsync(console); - await ExportAsync(console, ChannelIds); + + var cancellationToken = console.RegisterCancellationHandler(); + + await console.Output.WriteLineAsync("Resolving channel(s)..."); + + var channels = new List(); + var channelsByGuild = new Dictionary>(); + + foreach (var channelId in ChannelIds) + { + var channel = await Discord.GetChannelAsync(channelId, cancellationToken); + + // Unwrap categories + if (channel.IsCategory) + { + var guildChannels = + channelsByGuild.GetValueOrDefault(channel.GuildId) + ?? await Discord.GetGuildChannelsAsync(channel.GuildId, cancellationToken); + + foreach (var guildChannel in guildChannels) + { + if (guildChannel.Parent?.Id == channel.Id) + channels.Add(guildChannel); + } + + // Cache the guild channels to avoid redundant work + channelsByGuild[channel.GuildId] = guildChannels; + } + else + { + channels.Add(channel); + } + } + + await ExportAsync(console, channels); } } diff --git a/DiscordChatExporter.Cli/Commands/ExportGuildCommand.cs b/DiscordChatExporter.Cli/Commands/ExportGuildCommand.cs index 71b66467..64f18395 100644 --- a/DiscordChatExporter.Cli/Commands/ExportGuildCommand.cs +++ b/DiscordChatExporter.Cli/Commands/ExportGuildCommand.cs @@ -21,13 +21,6 @@ public class ExportGuildCommand : ExportCommandBase [CommandOption("include-vc", Description = "Include voice channels.")] public bool IncludeVoiceChannels { get; init; } = true; - [CommandOption( - "include-threads", - Description = "Which types of threads should be included.", - Converter = typeof(ThreadInclusionModeBindingConverter) - )] - public ThreadInclusionMode ThreadInclusionMode { get; init; } = ThreadInclusionMode.None; - public override async ValueTask ExecuteAsync(IConsole console) { await base.ExecuteAsync(console); @@ -66,40 +59,6 @@ public class ExportGuildCommand : ExportCommandBase await console.Output.WriteLineAsync($"Fetched {fetchedChannelsCount} channel(s)."); - // Threads - if (ThreadInclusionMode != ThreadInclusionMode.None) - { - await console.Output.WriteLineAsync("Fetching threads..."); - - var fetchedThreadsCount = 0; - await console - .CreateStatusTicker() - .StartAsync( - "...", - async ctx => - { - await foreach ( - var thread in Discord.GetGuildThreadsAsync( - GuildId, - ThreadInclusionMode == ThreadInclusionMode.All, - Before, - After, - cancellationToken - ) - ) - { - channels.Add(thread); - - ctx.Status(Markup.Escape($"Fetched '{thread.GetHierarchicalName()}'.")); - - fetchedThreadsCount++; - } - } - ); - - await console.Output.WriteLineAsync($"Fetched {fetchedThreadsCount} thread(s)."); - } - await ExportAsync(console, channels); } } diff --git a/DiscordChatExporter.Core/Discord/DiscordClient.cs b/DiscordChatExporter.Core/Discord/DiscordClient.cs index a607113c..134ac01f 100644 --- a/DiscordChatExporter.Core/Discord/DiscordClient.cs +++ b/DiscordChatExporter.Core/Discord/DiscordClient.cs @@ -305,7 +305,31 @@ public class DiscordClient( if (guildId == Guild.DirectMessages.Id) yield break; - var channels = (await GetGuildChannelsAsync(guildId, cancellationToken)) + var channels = await GetGuildChannelsAsync(guildId, cancellationToken); + + foreach ( + var channel in await GetChannelThreadsAsync( + channels, + includeArchived, + before, + after, + cancellationToken + ) + ) + { + yield return channel; + } + } + + public async IAsyncEnumerable GetChannelThreadsAsync( + IEnumerable channels, + bool includeArchived = false, + Snowflake? before = null, + Snowflake? after = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default + ) + { + Channel[] filteredChannels = channels // Categories cannot have threads .Where(c => !c.IsCategory) // Voice channels cannot have threads @@ -322,7 +346,7 @@ public class DiscordClient( // User accounts can only fetch threads using the search endpoint if (await ResolveTokenKindAsync(cancellationToken) == TokenKind.User) { - foreach (var channel in channels) + foreach (var channel in filteredChannels) { // Either include both active and archived threads, or only active threads foreach ( @@ -378,9 +402,14 @@ public class DiscordClient( // Bot accounts can only fetch threads using the threads endpoint else { + var guilds = new HashSet(); + foreach (var channel in filteredChannels) + guilds.Add(channel.GuildId); + // Active threads + foreach (var guildId in guilds) { - var parentsById = channels.ToDictionary(c => c.Id); + var parentsById = filteredChannels.ToDictionary(c => c.Id); var response = await GetJsonResponseAsync( $"guilds/{guildId}/threads/active", @@ -395,14 +424,15 @@ public class DiscordClient( ?.Pipe(Snowflake.Parse) .Pipe(parentsById.GetValueOrDefault); - yield return Channel.Parse(threadJson, parent); + if (filteredChannels.Contains(parent)) + yield return Channel.Parse(threadJson, parent); } } // Archived threads if (includeArchived) { - foreach (var channel in channels) + foreach (var channel in filteredChannels) { foreach (var archiveType in new[] { "public", "private" }) {