瀏覽代碼

refactor!(middlewares): add UserContextMiddleware and move all user middlewares to submodule

Librellium 1 周之前
父節點
當前提交
90635738e5

+ 0 - 13
anonflow/bot/middlewares/__init__.py

@@ -1,13 +0,0 @@
-from .banned import BannedMiddleware
-from .language import LanguageMiddleware
-from .not_registered import NotRegisteredMiddleware
-from .subscription import SubscriptionMiddleware
-from .throttling import ThrottlingMiddleware
-
-__all__ = [
-    "BannedMiddleware",
-    "LanguageMiddleware",
-    "NotRegisteredMiddleware",
-    "SubscriptionMiddleware",
-    "ThrottlingMiddleware"
-]

+ 15 - 0
anonflow/bot/middlewares/user/__init__.py

@@ -0,0 +1,15 @@
+from .banned import UserBannedMiddleware
+from .context import UserContextMiddleware
+from .language import UserLanguageMiddleware
+from .not_registered import UserNotRegisteredMiddleware
+from .subscription import UserSubscriptionMiddleware
+from .throttling import UserThrottlingMiddleware
+
+__all__ = [
+    "UserBannedMiddleware",
+    "UserContextMiddleware",
+    "UserLanguageMiddleware",
+    "UserNotRegisteredMiddleware",
+    "UserSubscriptionMiddleware",
+    "UserThrottlingMiddleware"
+]

+ 2 - 2
anonflow/bot/middlewares/banned.py → anonflow/bot/middlewares/user/banned.py

@@ -1,12 +1,12 @@
 from aiogram import BaseMiddleware
 from aiogram.types import Message
 
+from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
 from anonflow.services import ModeratorService
-from anonflow.services.transport.types import RequestContext
 
 
-class BannedMiddleware(BaseMiddleware):
+class UserBannedMiddleware(BaseMiddleware):
     def __init__(self, responses_port: UserResponsesPort, moderator_service: ModeratorService):
         super().__init__()
 

+ 20 - 0
anonflow/bot/middlewares/user/context.py

@@ -0,0 +1,20 @@
+from aiogram import BaseMiddleware
+from aiogram.types import Message
+
+from anonflow.services import UserService
+
+
+class UserContextMiddleware(BaseMiddleware):
+    def __init__(self, user_service: UserService):
+        super().__init__()
+
+        self._user_service = user_service
+
+    async def __call__(self, handler, event, data):
+        data["user"] = None
+
+        message = getattr(event, "message", None)
+        if isinstance(message, Message) and message.from_user:
+            data["user"] = await self._user_service.get(message.from_user.id)
+
+        return await handler(event, data)

+ 3 - 7
anonflow/bot/middlewares/language.py → anonflow/bot/middlewares/user/language.py

@@ -1,21 +1,17 @@
 from aiogram import BaseMiddleware
 from aiogram.types import Message
 
-from anonflow.services import UserService
 
-
-class LanguageMiddleware(BaseMiddleware):
-    def __init__(self, user_service: UserService):
+class UserLanguageMiddleware(BaseMiddleware):
+    def __init__(self):
         super().__init__()
 
-        self._user_service = user_service
-
     async def __call__(self, handler, event, data):
         data["user_language"] = None
 
         message = getattr(event, "message", None)
         if isinstance(message, Message) and message.from_user:
-            user = await self._user_service.get(message.from_user.id)
+            user = data.get("user")
             data["user_language"] = (
                 user.language
                 if user else message.from_user.language_code

+ 4 - 6
anonflow/bot/middlewares/not_registered.py → anonflow/bot/middlewares/user/not_registered.py

@@ -2,24 +2,22 @@ from aiogram import BaseMiddleware
 from aiogram.enums import ChatType
 from aiogram.types import Message
 
+from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
-from anonflow.services import UserService
-from anonflow.services.transport.types import RequestContext
 
 
-class NotRegisteredMiddleware(BaseMiddleware):
-    def __init__(self, responses_port: UserResponsesPort, user_service: UserService):
+class UserNotRegisteredMiddleware(BaseMiddleware):
+    def __init__(self, responses_port: UserResponsesPort):
         super().__init__()
 
         self._responses_port = responses_port
-        self._user_service = user_service
 
     async def __call__(self, handler, event, data):
         message = getattr(event, "message", None)
         if isinstance(message, Message) and message.chat.type == ChatType.PRIVATE:
             text = message.text or message.caption or ""
 
-            is_user_exists = await self._user_service.has(message.chat.id)
+            is_user_exists = data.get("user") is not None
             if not is_user_exists and not text.startswith("/start"):
                 await self._responses_port.user_not_registered(RequestContext(message.chat.id, data["user_language"]))
                 return

+ 12 - 7
anonflow/bot/middlewares/subscription.py → anonflow/bot/middlewares/user/subscription.py

@@ -1,15 +1,15 @@
-from typing import Tuple
+from typing import Iterable
 
 from aiogram import BaseMiddleware
 from aiogram.enums import ChatMemberStatus, ChatType
 from aiogram.types import ChatIdUnion, Message
 
+from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
-from anonflow.services.transport.types import RequestContext
 
 
-class SubscriptionMiddleware(BaseMiddleware):
-    def __init__(self, responses_port: UserResponsesPort, channel_ids: Tuple[ChatIdUnion]):
+class UserSubscriptionMiddleware(BaseMiddleware):
+    def __init__(self, responses_port: UserResponsesPort, channel_ids: Iterable[ChatIdUnion]):
         super().__init__()
 
         self._responses_port = responses_port
@@ -17,10 +17,15 @@ class SubscriptionMiddleware(BaseMiddleware):
 
     async def __call__(self, handler, event, data):
         message = getattr(event, "message", None)
-        if isinstance(message, Message) and message.chat.type == ChatType.PRIVATE:
-            user_id = message.from_user.id # type: ignore
+        if (
+            isinstance(message, Message)
+            and message.chat.type == ChatType.PRIVATE
+            and message.from_user
+            and message.bot
+        ):
+            user_id = message.from_user.id
             for channel_id in self._channel_ids:
-                member = await message.bot.get_chat_member(channel_id, user_id) # type: ignore
+                member = await message.bot.get_chat_member(channel_id, user_id)
                 if member.status in (ChatMemberStatus.KICKED, ChatMemberStatus.LEFT):
                     await self._responses_port.user_subscription_required(RequestContext(message.chat.id, data["user_language"]))
                     return

+ 3 - 3
anonflow/bot/middlewares/throttling.py → anonflow/bot/middlewares/user/throttling.py

@@ -5,16 +5,16 @@ from typing import Dict, Iterable, Optional
 from aiogram import BaseMiddleware
 from aiogram.types import ChatIdUnion, Message
 
+from anonflow.bot.transport.types import RequestContext
 from anonflow.interfaces import UserResponsesPort
-from anonflow.services.transport.types import RequestContext
 
 
-class ThrottlingMiddleware(BaseMiddleware):
+class UserThrottlingMiddleware(BaseMiddleware):
     def __init__(
         self,
         responses_port: UserResponsesPort,
         delay: float,
-        allowed_chat_ids: Optional[Iterable[ChatIdUnion]]
+        allowed_chat_ids: Optional[Iterable[ChatIdUnion]] = None
     ):
         super().__init__()