Przeglądaj źródła

Refactor MediaRouter to use EventHandler

Librellium 3 miesięcy temu
rodzic
commit
91782be0bc
1 zmienionych plików z 81 dodań i 75 usunięć
  1. 81 75
      anonflow/bot/routers/media.py

+ 81 - 75
anonflow/bot/routers/media.py

@@ -1,34 +1,38 @@
 import asyncio
 from asyncio import CancelledError
 from typing import Dict, List, Optional
-import logging
 
 from aiogram import Bot, F, Router
 from aiogram.enums import ChatType
 from aiogram.exceptions import TelegramBadRequest, TelegramForbiddenError
 from aiogram.types import InputMediaPhoto, InputMediaVideo, Message
 
+from anonflow.bot.utils.event_handler import EventHandler
 from anonflow.bot.utils.message_manager import MessageManager
 from anonflow.bot.utils.template_renderer import TemplateRenderer
 from anonflow.config import Config
-from anonflow.moderation import ModerationExecutor
+from anonflow.moderation import ModerationDecisionEvent, ModerationExecutor
 
 
 class MediaRouter(Router):
-    def __init__(self,
-                 config: Config,
-                 message_manager: MessageManager,
-                 template_renderer: TemplateRenderer,
-                 moderation_executor: Optional[ModerationExecutor] = None):
+    def __init__(
+        self,
+        config: Config,
+        message_manager: MessageManager,
+        template_renderer: TemplateRenderer,
+        moderation_executor: Optional[ModerationExecutor] = None,
+        event_handler: Optional[EventHandler] = None,
+    ):
         super().__init__()
 
         self.config = config
         self.message_manager = message_manager
         self.renderer = template_renderer
         self.executor = moderation_executor
+        self.event_handler = event_handler
 
-        self.media_groups: Dict[int, List[str]] = {}
-        self.media_groups_tasks: Dict[int, asyncio.Task] = {}
+        self.media_groups: Dict[str, List[Message]] = {}
+        self.media_groups_tasks: Dict[str, asyncio.Task] = {}
         self.media_groups_lock = asyncio.Lock()
 
         self._register_handlers()
@@ -36,36 +40,33 @@ class MediaRouter(Router):
     def _register_handlers(self):
         @self.message(F.photo | F.video)
         async def on_photo(message: Message, bot: Bot):
-            if ("photo" not in self.config.forwarding.types\
-                and "video" not in self.config.forwarding.types)\
-                    or message.chat.type != ChatType.PRIVATE:
+            if message.chat.type != ChatType.PRIVATE:
                 return
 
             def can_send_media(msgs: List[Message]):
                 photos = len([msg for msg in msgs if msg.photo])
                 videos = len([msg for msg in msgs if msg.video])
 
-                return (photos and "photo" in self.config.forwarding.types) or (videos and "video" in self.config.forwarding.types)
+                return (photos and "photo" in self.config.forwarding.types) or (
+                    videos and "video" in self.config.forwarding.types
+                )
 
-            def get_media(msg: Message):
-                caption = self.renderer.render("messages/channel/media.j2", message=msg)
-                parse_mode = "HTML" if msg.caption else None
+            async def get_media(msg: Message):
+                caption = await self.renderer.render(
+                    "messages/channel/media.j2", message=msg
+                )
 
                 if msg.photo and "photo" in self.config.forwarding.types:
-                    return InputMediaPhoto(
-                        media=msg.photo[-1].file_id,
-                        caption=caption,
-                        parse_mode=parse_mode
-                    )
+                    return InputMediaPhoto(media=msg.photo[-1].file_id, caption=caption)
                 elif msg.video and "video" in self.config.forwarding.types:
-                    return InputMediaVideo(
-                        media=msg.video.file_id,
-                        caption=caption,
-                        parse_mode=parse_mode
-                    )
+                    return InputMediaVideo(media=msg.video.file_id, caption=caption)
 
             async def process_messages(messages: list[Message]):
-                if not messages: return
+                if not messages:
+                    return
+
+                moderation_chat_ids = self.config.forwarding.moderation_chat_ids
+                publication_channel_ids = self.config.forwarding.publication_channel_ids
 
                 reply_to_message_id = messages[0].message_id
 
@@ -75,38 +76,29 @@ class MediaRouter(Router):
                         moderation_passed = not moderation
 
                         group_message_id = None
-                        targets = {
-                            self.config.forwarding.moderation_chat_id: True
-                        }
 
-                        sent_message = await message.answer(
-                            await self.renderer.render("messages/users/moderation/pending.j2", message=message)
-                        )
+                        targets = {}
+                        if moderation_chat_ids:
+                            for chat_id in moderation_chat_ids:
+                                targets[chat_id] = True
+
                         if len(messages) > 1:
                             media = []
                             for msg in messages:
                                 if moderation and msg.caption:
-                                    async for event in self.executor.process_message(msg.caption):
-                                        if event.type == "moderation_decision":
-                                            if event.result.status == "APPROVE":
-                                                moderation_passed = True
-                                            elif event.result.status == "REJECT":
-                                                await message.answer(
-                                                    await self.renderer.render("messages/users/moderation/rejected.j2", message=message)
-                                                )
-
-                                    await sent_message.delete()
+                                    async for event in self.executor.process_message(msg):
+                                        if isinstance(event, ModerationDecisionEvent):
+                                            moderation_passed = event.approved
+                                        await self.event_handler.handle(event, message)
 
-                                media.append(get_media(msg))
+                                media.append(await get_media(msg))
 
-                            if moderation_passed:
-                                targets[self.config.forwarding.publication_chat_id] = False
+                            if publication_channel_ids and moderation_passed:
+                                for channel_id in publication_channel_ids:
+                                    targets[channel_id] = False
 
                             for target, save_message_id in targets.items():
-                                messages = await bot.send_media_group(
-                                    target,
-                                    media
-                                )
+                                messages = await bot.send_media_group(target, media)
 
                                 if save_message_id:
                                     group_message_id = messages[0].message_id
@@ -115,41 +107,52 @@ class MediaRouter(Router):
                             caption = msg.caption
 
                             if moderation and caption:
-                                async for event in self.executor.process_message(caption):
-                                    if event.type == "moderation_decision":
-                                        if event.result.status == "APPROVE":
-                                            moderation_passed = True
-                                        elif event.result.status == "REJECT":
-                                            await message.answer(
-                                                await self.renderer.render("messages/users/moderation/rejected.j2", message=message)
-                                            )
+                                async for event in self.executor.process_message(msg):
+                                    if isinstance(event, ModerationDecisionEvent):
+                                        moderation_passed = event.approved
+                                    await self.event_handler.handle(event, message)
 
-                                await sent_message.delete()
-
-                            targets[self.config.forwarding.publication_chat_id] = False
+                            if publication_channel_ids and moderation_passed:
+                                for channel_id in publication_channel_ids:
+                                    targets[channel_id] = False
 
                             func = bot.send_photo if msg.photo else bot.send_video
-                            file_id = msg.photo[-1].file_id if msg.photo else msg.video.file_id
+                            file_id = (
+                                msg.photo[-1].file_id
+                                if msg.photo
+                                else msg.video.file_id
+                            )
 
                             for target, save_message_id in targets.items():
-                                m = await func(
-                                    target,
-                                    file_id,
-                                    caption=await self.renderer.render("messages/channel/media.j2", message=msg),
-                                    parse_mode="HTML"
-                                )
+                                msg_id = (
+                                    await func(
+                                        target,
+                                        file_id,
+                                        caption=await self.renderer.render(
+                                            "messages/channel/media.j2", message=msg
+                                        ),
+                                    )
+                                ).message_id
 
                                 if save_message_id:
-                                    group_message_id = m.message_id
+                                    group_message_id = msg_id
 
-                        self.message_manager.add(reply_to_message_id, group_message_id, message.chat.id)
+                        self.message_manager.add(
+                            reply_to_message_id, group_message_id, message.chat.id
+                        )
                         if moderation_passed:
                             await message.answer(
-                                await self.renderer.render("messages/users/send/success.j2", message=message)
+                                await self.renderer.render(
+                                    "messages/users/send/success.j2", message=message
+                                )
                             )
                 except (TelegramBadRequest, TelegramForbiddenError) as e:
                     await message.answer(
-                        await self.renderer.render("messages/users/send/failure.j2", message=message, exception=e)
+                        await self.renderer.render(
+                            "messages/users/send/failure.j2",
+                            message=message,
+                            exception=e,
+                        )
                     )
 
             media_group_id = message.media_group_id
@@ -168,9 +171,12 @@ class MediaRouter(Router):
                 self.media_groups.setdefault(media_group_id, []).append(message)
 
                 task = self.media_groups_tasks.get(media_group_id)
-                if task: task.cancel()
+                if task:
+                    task.cancel()
 
-                self.media_groups_tasks[media_group_id] = asyncio.create_task(await_media_group())
+                self.media_groups_tasks[media_group_id] = asyncio.create_task(
+                    await_media_group()
+                )
                 return
 
-            await process_messages([message])
+            await process_messages([message])