|
1 | 1 | import discord |
2 | 2 | from pydis_core.utils import logging |
| 3 | +from sqlalchemy import update |
3 | 4 | from sqlalchemy.ext.asyncio import AsyncSession |
4 | 5 |
|
5 | 6 | from metricity import models |
| 7 | +from metricity.bot import Bot |
| 8 | +from metricity.config import BotConfig |
| 9 | +from metricity.database import async_session |
6 | 10 |
|
7 | 11 | log = logging.get_logger(__name__) |
8 | 12 |
|
@@ -39,3 +43,110 @@ async def sync_message(message: discord.Message, sess: AsyncSession, *, from_thr |
39 | 43 | args["thread_id"] = str(thread.id) |
40 | 44 |
|
41 | 45 | sess.add(models.Message(**args)) |
| 46 | + |
| 47 | + |
| 48 | +async def sync_channels(bot: Bot, guild: discord.Guild) -> None: |
| 49 | + """Sync channels and categories with the database.""" |
| 50 | + bot.channel_sync_in_progress.clear() |
| 51 | + |
| 52 | + log.info("Beginning category synchronisation process") |
| 53 | + |
| 54 | + async with async_session() as sess: |
| 55 | + for channel in guild.channels: |
| 56 | + if isinstance(channel, discord.CategoryChannel): |
| 57 | + if existing_cat := await sess.get(models.Category, str(channel.id)): |
| 58 | + existing_cat.name = channel.name |
| 59 | + else: |
| 60 | + sess.add(models.Category(id=str(channel.id), name=channel.name, deleted=False)) |
| 61 | + |
| 62 | + await sess.commit() |
| 63 | + |
| 64 | + log.info("Category synchronisation process complete, synchronising deleted categories") |
| 65 | + |
| 66 | + async with async_session() as sess: |
| 67 | + await sess.execute( |
| 68 | + update(models.Category) |
| 69 | + .where(~models.Category.id.in_( |
| 70 | + [str(channel.id) for channel in guild.channels if isinstance(channel, discord.CategoryChannel)], |
| 71 | + )) |
| 72 | + .values(deleted=True), |
| 73 | + ) |
| 74 | + await sess.commit() |
| 75 | + |
| 76 | + log.info("Deleted category synchronisation process complete, synchronising channels") |
| 77 | + |
| 78 | + async with async_session() as sess: |
| 79 | + for channel in guild.channels: |
| 80 | + if channel.category and channel.category.id in BotConfig.ignore_categories: |
| 81 | + continue |
| 82 | + |
| 83 | + if not isinstance(channel, discord.CategoryChannel): |
| 84 | + category_id = str(channel.category.id) if channel.category else None |
| 85 | + # Cast to bool so is_staff is False if channel.category is None |
| 86 | + is_staff = channel.id in BotConfig.staff_channels or bool( |
| 87 | + channel.category and channel.category.id in BotConfig.staff_categories, |
| 88 | + ) |
| 89 | + if db_chan := await sess.get(models.Channel, str(channel.id)): |
| 90 | + db_chan.name = channel.name |
| 91 | + else: |
| 92 | + sess.add(models.Channel( |
| 93 | + id=str(channel.id), |
| 94 | + name=channel.name, |
| 95 | + category_id=category_id, |
| 96 | + is_staff=is_staff, |
| 97 | + deleted=False, |
| 98 | + )) |
| 99 | + |
| 100 | + await sess.commit() |
| 101 | + |
| 102 | + log.info("Channel synchronisation process complete, synchronising deleted channels") |
| 103 | + |
| 104 | + async with async_session() as sess: |
| 105 | + await sess.execute( |
| 106 | + update(models.Channel) |
| 107 | + .where(~models.Channel.id.in_([str(channel.id) for channel in guild.channels])) |
| 108 | + .values(deleted=True), |
| 109 | + ) |
| 110 | + await sess.commit() |
| 111 | + |
| 112 | + log.info("Deleted channel synchronisation process complete, synchronising threads") |
| 113 | + |
| 114 | + async with async_session() as sess: |
| 115 | + for thread in guild.threads: |
| 116 | + if thread.parent and thread.parent.category: |
| 117 | + if thread.parent.category.id in BotConfig.ignore_categories: |
| 118 | + continue |
| 119 | + else: |
| 120 | + # This is a forum channel, not currently supported by Discord.py. Ignore it. |
| 121 | + continue |
| 122 | + |
| 123 | + if db_thread := await sess.get(models.Thread, str(thread.id)): |
| 124 | + db_thread.name = thread.name |
| 125 | + db_thread.archived = thread.archived |
| 126 | + db_thread.auto_archive_duration = thread.auto_archive_duration |
| 127 | + db_thread.locked = thread.locked |
| 128 | + db_thread.type = thread.type.name |
| 129 | + else: |
| 130 | + insert_thread(thread, sess) |
| 131 | + await sess.commit() |
| 132 | + |
| 133 | + log.info("Thread synchronisation process complete, finished synchronising guild.") |
| 134 | + bot.channel_sync_in_progress.set() |
| 135 | + |
| 136 | + |
| 137 | +async def sync_thread_archive_state(guild: discord.Guild) -> None: |
| 138 | + """Sync the archive state of all threads in the database with the state in guild.""" |
| 139 | + active_thread_ids = [str(thread.id) for thread in guild.threads] |
| 140 | + |
| 141 | + async with async_session() as sess: |
| 142 | + await sess.execute( |
| 143 | + update(models.Thread) |
| 144 | + .where(models.Thread.id.in_(active_thread_ids)) |
| 145 | + .values(archived=False), |
| 146 | + ) |
| 147 | + await sess.execute( |
| 148 | + update(models.Thread) |
| 149 | + .where(~models.Thread.id.in_(active_thread_ids)) |
| 150 | + .values(archived=True), |
| 151 | + ) |
| 152 | + await sess.commit() |
0 commit comments