throttling.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import asyncio
  2. import time
  3. from typing import Dict, Iterable, Optional
  4. from aiogram import BaseMiddleware
  5. from aiogram.types import ChatIdUnion, Message
  6. from anonflow.bot.transport.types import RequestContext
  7. from anonflow.interfaces import UserResponsesPort
  class UserThrottlingMiddleware(BaseMiddleware):
      """Reject non-command messages from a chat while its previous one is busy.

      A chat becomes busy when one of its messages is passed to the next
      handler. It stays busy until the handler finishes and at least ``delay``
      seconds (measured with ``time.monotonic``) have elapsed since that
      moment; the middleware sleeps out whatever part of ``delay`` the handler
      did not use. Further non-command messages from a busy chat are not passed
      on; instead ``responses_port.user_throttled`` is called with the
      remaining wait time.
      """

      def __init__(
          self,
          responses_port: UserResponsesPort,
          delay: float,
          allowed_chat_ids: Optional[Iterable[ChatIdUnion]] = None,
      ):
          """Create the middleware.

          Args:
              responses_port: Port used to notify a throttled chat how long it
                  has to wait.
              delay: Minimum time, in seconds, a chat stays busy after one of
                  its messages starts being handled (unless the handler raises).
              allowed_chat_ids: Chats exempt from throttling. Throttling is only
                  applied when this is given; with ``None`` every update is
                  passed straight through.
                  NOTE(review): ``None`` therefore disables throttling entirely
                  rather than meaning "no exemptions" -- confirm intended.
          """
          super().__init__()
          self._responses_port = responses_port
          self._delay = delay
          # Materialise once: a one-shot iterable (e.g. a generator) would be
          # consumed by the first ``in`` test, and a set gives O(1) lookups.
          self._allowed_chat_ids = (
              frozenset(allowed_chat_ids) if allowed_chat_ids is not None else None
          )
          # chat id -> monotonic time at which its in-flight message started.
          # An entry exists only while the chat is busy, so this never grows
          # beyond the number of concurrently busy chats.
          self._user_times: Dict[int, float] = {}

      def _should_throttle(self, message: object) -> bool:
          """Return True if *message* is subject to throttling."""
          if not isinstance(message, Message):
              return False
          if (
              self._allowed_chat_ids is None
              or message.chat.id in self._allowed_chat_ids
          ):
              return False
          text = message.text or message.caption or ""
          # Commands (text or caption starting with "/") are never throttled.
          return not text.startswith("/")

      async def __call__(self, handler, event, data):
          """Pass *event* to *handler* unless its chat is currently busy.

          Args:
              handler: Next handler in the middleware chain.
              event: Incoming event; only its ``message`` attribute is
                  inspected, and only when that is a ``Message``.
              data: Handler context. ``data["user_language"]`` is read when a
                  throttle notice is sent (presumably populated by an earlier
                  middleware -- verify).

          Returns:
              The handler's result, or ``None`` when the message was throttled.
          """
          message = getattr(event, "message", None)
          if not self._should_throttle(message):
              return await handler(event, data)

          chat_id = message.chat.id
          busy_since = self._user_times.get(chat_id)
          if busy_since is not None:
              remaining = self._delay - (time.monotonic() - busy_since)
              # Sent without holding any shared lock, so a slow reply to one
              # chat cannot stall other chats.
              await self._responses_port.user_throttled(
                  RequestContext(chat_id, data["user_language"]),
                  # Clamp: a handler running longer than ``delay`` would
                  # otherwise produce a negative wait time.
                  remaining_time=max(0, round(remaining)),
              )
              return None

          # There is no ``await`` between the lookup above and this assignment,
          # so on a single event loop the check-and-mark is atomic: at most one
          # request per chat can get past this point at a time.
          start_time = time.monotonic()
          self._user_times[chat_id] = start_time
          try:
              result = await handler(event, data)
              elapsed_time = time.monotonic() - start_time
              await asyncio.sleep(max(0, self._delay - elapsed_time))
          finally:
              # Free the chat even if the handler raises or the task is
              # cancelled. Only the request that marked the chat reaches here.
              self._user_times.pop(chat_id, None)
          return result