executor.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import asyncio
  2. import logging
  3. import textwrap
  4. from typing import AsyncGenerator, Optional, Literal
  5. from .events import Event, ModerationDecisionEvent, ModerationStartedEvent
  6. from .planner import ModerationPlanner
  7. class ModerationExecutor:
  8. def __init__(self, moderation_planner: ModerationPlanner):
  9. self._logger = logging.getLogger(__name__)
  10. self._moderation_planner = moderation_planner
  11. self._moderation_planner.set_functions(self.moderation_decision)
  12. def moderation_decision(self, status: Literal["approve", "reject"], reason: str):
  13. moderation_map = {"approve": True, "reject": False}
  14. return ModerationDecisionEvent(
  15. is_approved=moderation_map.get(status.lower(), False), reason=reason
  16. )
  17. moderation_decision.description = textwrap.dedent( # type: ignore
  18. """
  19. Processes a message with a moderation decision by status and reason.
  20. This function must be called whenever there is no exact user request or no other available function
  21. matching the user's intent. Status must be either "approve if the message is allowed, or "reject" if it should be blocked.
  22. """
  23. ).strip()
  24. async def process(
  25. self, text: Optional[str] = None, image: Optional[str] = None
  26. ) -> AsyncGenerator[Event, None]:
  27. yield ModerationStartedEvent()
  28. functions = await self._moderation_planner.plan(text, image)
  29. function_names = self._moderation_planner.get_function_names()
  30. for function in functions:
  31. func_name = function.get("name", "")
  32. func_args = function.get("args", {})
  33. method = getattr(self, func_name, None)
  34. if method is None or func_name not in function_names:
  35. self._logger.warning("Function %s not found, skipping.", func_name)
  36. continue
  37. self._logger.info("Executing %s.", func_name)
  38. try:
  39. if asyncio.iscoroutinefunction(method):
  40. yield await method(**func_args)
  41. else:
  42. yield await asyncio.to_thread(method, **func_args)
  43. except Exception:
  44. self._logger.exception("Failed to execute %s.", func_name)