|
| 1 | +import discord |
| 2 | +from pydis_core.utils import logging |
| 3 | +from sqlalchemy import update |
| 4 | +from sqlalchemy.ext.asyncio import AsyncSession |
| 5 | + |
| 6 | +from metricity import models |
| 7 | +from metricity.bot import Bot |
| 8 | +from metricity.config import BotConfig |
| 9 | +from metricity.database import async_session |
| 10 | + |
| 11 | +log = logging.get_logger(__name__) |
| 12 | + |
| 13 | + |
| 14 | +def insert_thread(thread: discord.Thread, sess: AsyncSession) -> None: |
| 15 | + """Insert the given thread to the database session.""" |
| 16 | + sess.add(models.Thread( |
| 17 | + id=str(thread.id), |
| 18 | + parent_channel_id=str(thread.parent_id), |
| 19 | + name=thread.name, |
| 20 | + archived=thread.archived, |
| 21 | + auto_archive_duration=thread.auto_archive_duration, |
| 22 | + locked=thread.locked, |
| 23 | + type=thread.type.name, |
| 24 | + created_at=thread.created_at, |
| 25 | + )) |
| 26 | + |
| 27 | + |
| 28 | +async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thread: bool) -> None: |
| 29 | + """Sync the given message with the database.""" |
| 30 | + if await sess.get(models.Message, str(message.id)): |
| 31 | + return |
| 32 | + |
| 33 | + args = { |
| 34 | + "id": str(message.id), |
| 35 | + "channel_id": str(message.channel.id), |
| 36 | + "author_id": str(message.author.id), |
| 37 | + "created_at": message.created_at, |
| 38 | + } |
| 39 | + |
| 40 | + if from_thread: |
| 41 | + thread = message.channel |
| 42 | + args["channel_id"] = str(thread.parent_id) |
| 43 | + args["thread_id"] = str(thread.id) |
| 44 | + |
| 45 | + sess.add(models.Message(**args)) |
| 46 | + |
| 47 | + |
| 48 | +async def sync_channels(bot: Bot, guild: discord.Guild) -> None: |
| 49 | + """Sync channels and categories with the database.""" |
| 50 | + bot.channel_sync_in_progress.clear() |
| 51 | + |
| 52 | + log.info("Beginning category synchronisation process") |
| 53 | + |
| 54 | + async with async_session() as sess: |
| 55 | + for channel in guild.channels: |
| 56 | + if isinstance(channel, discord.CategoryChannel): |
| 57 | + if existing_cat := await sess.get(models.Category, str(channel.id)): |
| 58 | + existing_cat.name = channel.name |
| 59 | + else: |
| 60 | + sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) |
| 61 | + |
| 62 | + await sess.commit() |
| 63 | + |
| 64 | + log.info("Category synchronisation process complete, synchronising deleted categories") |
| 65 | + |
| 66 | + async with async_session() as sess: |
| 67 | + await sess.execute( |
| 68 | + update(models.Category) |
| 69 | + .where(~models.Category.id.in_( |
| 70 | + [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], |
| 71 | + )) |
| 72 | + .values(deleted=True), |
| 73 | + ) |
| 74 | + await sess.commit() |
| 75 | + |
| 76 | + log.info("Deleted category synchronisation process complete, synchronising channels") |
| 77 | + |
| 78 | + async with async_session() as sess: |
| 79 | + for channel in guild.channels: |
| 80 | + if channel.category and channel.category.id in BotConfig.ignore_categories: |
| 81 | + continue |
| 82 | + |
| 83 | + if not isinstance(channel, discord.CategoryChannel): |
| 84 | + category_id = str(channel.category.id) if channel.category else None |
| 85 | + # Cast to bool so is_staff is False if channel.category is None |
| 86 | + is_staff = channel.id in BotConfig.staff_channels or bool( |
| 87 | + channel.category and channel.category.id in BotConfig.staff_categories, |
| 88 | + ) |
| 89 | + if db_chan := await sess.get(models.Channel, str(channel.id)): |
| 90 | + db_chan.name = channel.name |
| 91 | + else: |
| 92 | + sess.add(models.Channel( |
| 93 | + id=str(channel.id), |
| 94 | + name=channel.name, |
| 95 | + category_id=category_id, |
| 96 | + is_staff=is_staff, |
| 97 | + deleted=False, |
| 98 | + )) |
| 99 | + |
| 100 | + await sess.commit() |
| 101 | + |
| 102 | + log.info("Channel synchronisation process complete, synchronising deleted channels") |
| 103 | + |
| 104 | + async with async_session() as sess: |
| 105 | + await sess.execute( |
| 106 | + update(models.Channel) |
| 107 | + .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) |
| 108 | + .values(deleted=True), |
| 109 | + ) |
| 110 | + await sess.commit() |
| 111 | + |
| 112 | + log.info("Deleted channel synchronisation process complete, synchronising threads") |
| 113 | + |
| 114 | + async with async_session() as sess: |
| 115 | + for thread in guild.threads: |
| 116 | + if thread.parent and thread.parent.category: |
| 117 | + if thread.parent.category.id in BotConfig.ignore_categories: |
| 118 | + continue |
| 119 | + else: |
| 120 | + # This is a forum channel, not currently supported by Discord.py. Ignore it. |
| 121 | + continue |
| 122 | + |
| 123 | + if db_thread := await sess.get(models.Thread, str(thread.id)): |
| 124 | + db_thread.name = thread.name |
| 125 | + db_thread.archived = thread.archived |
| 126 | + db_thread.auto_archive_duration = thread.auto_archive_duration |
| 127 | + db_thread.locked = thread.locked |
| 128 | + db_thread.type = thread.type.name |
| 129 | + else: |
| 130 | + insert_thread(thread, sess) |
| 131 | + await sess.commit() |
| 132 | + |
| 133 | + log.info("Thread synchronisation process complete, finished synchronising guild.") |
| 134 | + bot.channel_sync_in_progress.set() |
| 135 | + |
| 136 | + |
| 137 | +async def sync_thread_archive_state(guild: discord.Guild) -> None: |
| 138 | + """Sync the archive state of all threads in the database with the state in guild.""" |
| 139 | + active_thread_ids = [str(thread.id) for thread in guild.threads] |
| 140 | + |
| 141 | + async with async_session() as sess: |
| 142 | + await sess.execute( |
| 143 | + update(models.Thread) |
| 144 | + .where(models.Thread.id.in_(active_thread_ids)) |
| 145 | + .values(archived=False), |
| 146 | + ) |
| 147 | + await sess.execute( |
| 148 | + update(models.Thread) |
| 149 | + .where(~models.Thread.id.in_(active_thread_ids)) |
| 150 | + .values(archived=True), |
| 151 | + ) |
| 152 | + await sess.commit() |
0 commit comments