throttling.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import asyncio
  2. import time
  3. from typing import Dict, Iterable, Optional
  4. from aiogram import BaseMiddleware
  5. from aiogram.enums import ChatType
  6. from aiogram.types import ChatIdUnion, Message
  7. from anonflow.bot.transport.types import RequestContext
  8. from anonflow.interfaces import UserResponsesPort
  9. from .utils import extract_message
  10. class UserThrottlingMiddleware(BaseMiddleware):
  11. def __init__(
  12. self,
  13. responses_port: UserResponsesPort,
  14. delay: float,
  15. ignored_chat_ids: Optional[Iterable[Optional[ChatIdUnion]]] = None,
  16. ignored_commands: Optional[Iterable[str]] = None
  17. ):
  18. super().__init__()
  19. self._responses_port = responses_port
  20. self._delay = delay
  21. self._ignored_chat_ids = tuple(ignored_chat_ids or ())
  22. self._ignored_commands = tuple(ignored_commands or ())
  23. self._user_times: Dict[int, float] = {}
  24. self._user_locks: Dict[int, asyncio.Lock] = {}
  25. self._lock = asyncio.Lock()
  26. async def __call__(self, handler, event, data):
  27. message = extract_message(event)
  28. if (
  29. isinstance(message, Message)
  30. and message.chat.id not in self._ignored_chat_ids
  31. and message.chat.type == ChatType.PRIVATE
  32. ):
  33. text = message.text or message.caption or ""
  34. if not text.startswith(self._ignored_commands):
  35. async with self._lock:
  36. user_lock = self._user_locks.setdefault(
  37. message.chat.id, asyncio.Lock()
  38. )
  39. if user_lock.locked():
  40. start_time = self._user_times.get(message.chat.id) or 0
  41. current_time = time.monotonic()
  42. await self._responses_port.user_throttled(
  43. RequestContext(message.chat.id, data["user_language"]),
  44. remaining_time=(
  45. round(self._delay - (current_time - start_time))
  46. if start_time
  47. else 0
  48. ),
  49. )
  50. return
  51. async with user_lock:
  52. start_time = time.monotonic()
  53. self._user_times[message.chat.id] = start_time
  54. result = await handler(event, data)
  55. elapsed_time = time.monotonic() - start_time
  56. await asyncio.sleep(max(0, self._delay - elapsed_time))
  57. async with self._lock:
  58. self._user_locks.pop(message.chat.id)
  59. return result
  60. return await handler(event, data)