Explorar el Código

refactor(middlewares): move message and user extraction to utils module

Librellium hace 2 días
padre
commit
f1df8c9e45

+ 17 - 0
anonflow/bot/middlewares/__init__.py

@@ -0,0 +1,17 @@
+from .user import (
+    UserBannedMiddleware,
+    UserContextMiddleware,
+    UserLanguageMiddleware,
+    UserNotRegisteredMiddleware,
+    UserSubscriptionMiddleware,
+    UserThrottlingMiddleware
+)
+
+__all__ = [
+    "UserBannedMiddleware",
+    "UserContextMiddleware",
+    "UserLanguageMiddleware",
+    "UserNotRegisteredMiddleware",
+    "UserSubscriptionMiddleware",
+    "UserThrottlingMiddleware"
+]

+ 6 - 3
anonflow/bot/middlewares/user/banned.py

@@ -5,6 +5,8 @@ from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
 from anonflow.services import ModeratorService
 
+from .utils import extract_message, extract_user
+
 
 class UserBannedMiddleware(BaseMiddleware):
     def __init__(
@@ -16,9 +18,10 @@ class UserBannedMiddleware(BaseMiddleware):
         self._moderator_service = moderator_service
 
     async def __call__(self, handler, event, data):
-        message = getattr(event, "message", None)
-        if isinstance(message, Message):
-            if await self._moderator_service.is_banned(message.chat.id):
+        message = extract_message(event)
+        from_user = extract_user(event)
+        if isinstance(message, Message) and from_user:
+            if await self._moderator_service.is_banned(from_user.id):
                 await self._responses_port.user_banned(
                     RequestContext(message.chat.id, data["user_language"])
                 )

+ 5 - 4
anonflow/bot/middlewares/user/context.py

@@ -1,8 +1,9 @@
 from aiogram import BaseMiddleware
-from aiogram.types import Message
 
 from anonflow.services import UserService
 
+from .utils import extract_user
+
 
 class UserContextMiddleware(BaseMiddleware):
     def __init__(self, user_service: UserService):
@@ -13,8 +14,8 @@ class UserContextMiddleware(BaseMiddleware):
     async def __call__(self, handler, event, data):
         data["user"] = None
 
-        message = getattr(event, "message", None)
-        if isinstance(message, Message) and message.from_user:
-            data["user"] = await self._user_service.get(message.from_user.id)
+        from_user = extract_user(event)
+        if from_user:
+            data["user"] = await self._user_service.get(from_user.id)
 
         return await handler(event, data)

+ 9 - 6
anonflow/bot/middlewares/user/language.py

@@ -1,19 +1,22 @@
 from aiogram import BaseMiddleware
-from aiogram.types import Message
+
+from .utils import extract_user
 
 
 class UserLanguageMiddleware(BaseMiddleware):
-    def __init__(self):
+    def __init__(self, fallback_language: str):
         super().__init__()
 
+        self._fallback_language = fallback_language
+
     async def __call__(self, handler, event, data):
-        data["user_language"] = None
+        data["user_language"] = self._fallback_language
 
-        message = getattr(event, "message", None)
-        if isinstance(message, Message) and message.from_user:
+        from_user = extract_user(event)
+        if from_user:
             user = data.get("user")
             data["user_language"] = (
-                user.language if user else message.from_user.language_code
+                user.language if user else from_user.language_code
             )
 
         return await handler(event, data)

+ 3 - 1
anonflow/bot/middlewares/user/not_registered.py

@@ -5,6 +5,8 @@ from aiogram.types import Message
 from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
 
+from .utils import extract_message
+
 
 class UserNotRegisteredMiddleware(BaseMiddleware):
     def __init__(self, responses_port: UserResponsesPort):
@@ -13,7 +15,7 @@ class UserNotRegisteredMiddleware(BaseMiddleware):
         self._responses_port = responses_port
 
     async def __call__(self, handler, event, data):
-        message = getattr(event, "message", None)
+        message = extract_message(event)
         if isinstance(message, Message) and message.chat.type == ChatType.PRIVATE:
             text = message.text or message.caption or ""
 

+ 7 - 5
anonflow/bot/middlewares/user/subscription.py

@@ -7,6 +7,8 @@ from aiogram.types import ChatIdUnion, Message
 from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
 
+from .utils import extract_message, extract_user
+
 
 class UserSubscriptionMiddleware(BaseMiddleware):
     def __init__(
@@ -18,16 +20,16 @@ class UserSubscriptionMiddleware(BaseMiddleware):
         self._channel_ids = channel_ids
 
     async def __call__(self, handler, event, data):
-        message = getattr(event, "message", None)
+        message = extract_message(event)
+        from_user = extract_user(event)
         if (
             isinstance(message, Message)
+            and from_user is not None
+            and message.bot is not None
             and message.chat.type == ChatType.PRIVATE
-            and message.from_user
-            and message.bot
         ):
-            user_id = message.from_user.id
             for channel_id in self._channel_ids:
-                member = await message.bot.get_chat_member(channel_id, user_id)
+                member = await message.bot.get_chat_member(channel_id, from_user.id)
                 if member.status in (ChatMemberStatus.KICKED, ChatMemberStatus.LEFT):
                     await self._responses_port.user_subscription_required(
                         RequestContext(message.chat.id, data["user_language"])

+ 13 - 7
anonflow/bot/middlewares/user/throttling.py

@@ -3,24 +3,29 @@ import time
 from typing import Dict, Iterable, Optional
 
 from aiogram import BaseMiddleware
+from aiogram.enums import ChatType
 from aiogram.types import ChatIdUnion, Message
 
 from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
 
+from .utils import extract_message
+
 
 class UserThrottlingMiddleware(BaseMiddleware):
     def __init__(
         self,
         responses_port: UserResponsesPort,
         delay: float,
-        allowed_chat_ids: Optional[Iterable[ChatIdUnion]] = None,
+        ignored_chat_ids: Optional[Iterable[Optional[ChatIdUnion]]] = None,
+        ignored_commands: Optional[Iterable[str]] = None
     ):
         super().__init__()
 
         self._responses_port = responses_port
         self._delay = delay
-        self._allowed_chat_ids = allowed_chat_ids
+        self._ignored_chat_ids = tuple(ignored_chat_ids or ())
+        self._ignored_commands = tuple(ignored_commands or ())
 
         self._user_times: Dict[int, float] = {}
         self._user_locks: Dict[int, asyncio.Lock] = {}
@@ -28,13 +33,14 @@ class UserThrottlingMiddleware(BaseMiddleware):
         self._lock = asyncio.Lock()
 
     async def __call__(self, handler, event, data):
-        message = getattr(event, "message", None)
-        if isinstance(message, Message) and (
-            self._allowed_chat_ids is not None
-            and message.chat.id not in self._allowed_chat_ids
+        message = extract_message(event)
+        if (
+            isinstance(message, Message)
+            and message.chat.id not in self._ignored_chat_ids
+            and message.chat.type == ChatType.PRIVATE
         ):
             text = message.text or message.caption or ""
-            if not text.startswith("/"):
+            if not text.startswith(self._ignored_commands):
                 async with self._lock:
                     user_lock = self._user_locks.setdefault(
                         message.chat.id, asyncio.Lock()

+ 16 - 0
anonflow/bot/middlewares/user/utils.py

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from aiogram.types import Message, User
+
+
+def extract_message(event) -> Optional[Message]:
+    if message := getattr(event, "message", None):
+        return message
+    if callback_query := getattr(event, "callback_query", None):
+        return callback_query.message
+
+def extract_user(event) -> Optional[User]:
+    if message := getattr(event, "message", None):
+        return message.from_user
+    if callback_query := getattr(event, "callback_query", None):
+        return callback_query.from_user