Przeglądaj źródła

Refactor ModerationExecutor to improve code readability

Librellium 3 miesięcy temu
rodzic
commit
adfe37bf22
1 zmienionych plików z 48 dodań i 52 usunięć
  1. 48 52
      anonflow/moderation/executor.py

+ 48 - 52
anonflow/moderation/executor.py

@@ -3,32 +3,32 @@ from typing import AsyncGenerator, Literal
 
 from aiogram import Bot
 from aiogram.exceptions import TelegramBadRequest
+from aiogram.types import Message
 from yarl import URL
 
 from anonflow.bot.utils.template_renderer import TemplateRenderer
 from anonflow.config import Config
 
-from .models import ModerationDecision, ModerationEvent
+from .models import Events, ExecutorDeletionEvent, ModerationDecisionEvent, ModerationStartedEvent
 from .planner import ModerationPlanner
 
 
 class ModerationExecutor:
-    def __init__(self,
-                 config: Config,
-                 bot: Bot,
-                 template_renderer: TemplateRenderer,
-                 planner: ModerationPlanner):
+    def __init__(
+        self,
+        bot: Bot,
+        config: Config,
+        template_renderer: TemplateRenderer,
+        planner: ModerationPlanner,
+    ):
         self._logger = logging.getLogger(__name__)
 
-        self.config = config
         self.bot = bot
+        self.config = config
         self.renderer = template_renderer
 
         self.planner = planner
-        self.planner.set_functions(
-            self.delete_message,
-            self.moderation_decision
-        )
+        self.planner.set_functions(self.delete_message, self.moderation_decision)
 
     async def delete_message(self, message_link: str):
         """
@@ -39,60 +39,56 @@ class ModerationExecutor:
         parsed_url = URL(message_link)
         parsed_path = parsed_url.path.strip("/").split("/")
 
-        moderation_chat_id = self.config.forwarding.moderation_chat_id
         publication_chat_id = self.config.forwarding.publication_chat_id
 
-        if len(parsed_path) == 3 and parsed_path[0] == "c"\
-            and parsed_path[1].replace("-100", "") == str(publication_chat_id).replace("-100", ""):
-                message_id = parsed_path[2]
-                try:
-                    await self.bot.delete_message(publication_chat_id, message_id)
-                    await self.bot.send_message(
-                        moderation_chat_id,
-                        await self.renderer.render("messages/staff/deletion/success.j2", message_id=message_id),
-                        parse_mode="HTML"
-                    )
-                    return ModerationEvent(type="delete_message", result=True)
-                except TelegramBadRequest:
-                    await self.bot.send_message(
-                        moderation_chat_id,
-                        await self.renderer.render("messages/staff/deletion/failure.j2", message_id=message_id),
-                        parse_mode="HTML"
-                    )
-                    return ModerationEvent(type="delete_message", result=False)
-
-    async def moderation_decision(self, status: Literal["APPROVE", "REJECT"], explanation: str):
+        if not publication_chat_id:
+            return ExecutorDeletionEvent(success=False)
+
+        if (
+            len(parsed_path) == 3
+            and parsed_path[0] == "c"
+            and parsed_path[1].replace("-100", "")
+            == str(publication_chat_id).replace("-100", "")
+        ):
+            message_id = parsed_path[2]
+            try:
+                await self.bot.delete_message(publication_chat_id, message_id)
+                return ExecutorDeletionEvent(success=True, message_id=message_id)
+            except TelegramBadRequest:
+                return ExecutorDeletionEvent(success=False, message_id=message_id)
+
+        return ExecutorDeletionEvent(success=False)
+
+    async def moderation_decision(
+        self, status: Literal["APPROVE", "REJECT"], explanation: str
+    ):
         """
-        Processes a message with a moderation decision by status and explanation. 
-        This function must be called whenever there is no exact user request or no other available function 
+        Processes a message with a moderation decision by status and explanation.
+        This function must be called whenever there is no exact user request or no other available function
         matching the user's intent. Status must be either "APPROVE" if the message is allowed, or "REJECT" if it should be blocked.
         """
-        moderation_chat_id = self.config.forwarding.moderation_chat_id
-
-        await self.bot.send_message(
-            moderation_chat_id,
-            await self.renderer.render(
-                f"messages/staff/moderation/{'approved' if status == 'APPROVE' else 'rejected'}.j2",
-                status=status,
-                explanation=explanation
-            ),
-            parse_mode="HTML"
-        )
 
-        if status not in ("APPROVE", "REJECT"):
-            return ModerationEvent(type="moderation_decision", result=ModerationDecision(status="REJECT", explanation=explanation))
+        moderation_map = {"APPROVE": True, "REJECT": False}
+
+        return ModerationDecisionEvent(
+            approved=moderation_map.get(status, False), explanation=explanation
+        )
 
-        return ModerationEvent(type="moderation_decision", result=ModerationDecision(status=status, explanation=explanation))
+    async def process_message(self, message: Message) -> AsyncGenerator[Events, None]:
+        yield ModerationStartedEvent()
 
-    async def process_message(self, message_text: str) -> AsyncGenerator[ModerationEvent, None]:
-        functions = await self.planner.plan(message_text)
+        functions = await self.planner.plan((message.text or message.caption))
         function_names = self.planner.get_function_names()
 
         for func in functions:
             func_name = func.get("name")
             if hasattr(self, func_name) and func_name in function_names:
                 try:
-                    self._logger.info(f"Executing {func_name}({', '.join(map(str, func.get('args')))})")
+                    self._logger.info(
+                        f"Executing {func_name}({', '.join(map(str, func.get('args')))})"
+                    )
                     yield await getattr(self, func_name)(*func.get("args"))
                 except Exception:
-                    self._logger.exception(f"Failed to execute {func_name}({', '.join(map(str, func.get('args')))})")
+                    self._logger.exception(
+                        f"Failed to execute {func_name}({', '.join(map(str, func.get('args')))})"
+                    )