|
|
@@ -3,24 +3,29 @@ import time
|
|
|
from typing import Dict, Iterable, Optional
|
|
|
|
|
|
from aiogram import BaseMiddleware
|
|
|
+from aiogram.enums import ChatType
|
|
|
from aiogram.types import ChatIdUnion, Message
|
|
|
|
|
|
from anonflow.bot.transport.types import RequestContext
|
|
|
from anonflow.interfaces import UserResponsesPort
|
|
|
|
|
|
+from .utils import extract_message
|
|
|
+
|
|
|
|
|
|
class UserThrottlingMiddleware(BaseMiddleware):
|
|
|
def __init__(
|
|
|
self,
|
|
|
responses_port: UserResponsesPort,
|
|
|
delay: float,
|
|
|
- allowed_chat_ids: Optional[Iterable[ChatIdUnion]] = None,
|
|
|
+ ignored_chat_ids: Optional[Iterable[Optional[ChatIdUnion]]] = None,
|
|
|
+ ignored_commands: Optional[Iterable[str]] = None
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
self._responses_port = responses_port
|
|
|
self._delay = delay
|
|
|
- self._allowed_chat_ids = allowed_chat_ids
|
|
|
+ self._ignored_chat_ids = tuple(ignored_chat_ids or ())
|
|
|
+ self._ignored_commands = tuple(ignored_commands or ())
|
|
|
|
|
|
self._user_times: Dict[int, float] = {}
|
|
|
self._user_locks: Dict[int, asyncio.Lock] = {}
|
|
|
@@ -28,13 +33,14 @@ class UserThrottlingMiddleware(BaseMiddleware):
|
|
|
self._lock = asyncio.Lock()
|
|
|
|
|
|
async def __call__(self, handler, event, data):
|
|
|
- message = getattr(event, "message", None)
|
|
|
- if isinstance(message, Message) and (
|
|
|
- self._allowed_chat_ids is not None
|
|
|
- and message.chat.id not in self._allowed_chat_ids
|
|
|
+ message = extract_message(event)
|
|
|
+ if (
|
|
|
+ isinstance(message, Message)
|
|
|
+ and message.chat.id not in self._ignored_chat_ids
|
|
|
+ and message.chat.type == ChatType.PRIVATE
|
|
|
):
|
|
|
text = message.text or message.caption or ""
|
|
|
- if not text.startswith("/"):
|
|
|
+ if not text.startswith(self._ignored_commands):
|
|
|
async with self._lock:
|
|
|
user_lock = self._user_locks.setdefault(
|
|
|
message.chat.id, asyncio.Lock()
|