Просмотр исходного кода

Refactor ModerationPlanner to improve code readability

Librellium 3 месяцев назад
Родитель
Сommit
f9c87e9e9b
1 измененных файлов с 81 добавлено и 50 удалено
  1. 81 50
      anonflow/moderation/planner.py

+ 81 - 50
anonflow/moderation/planner.py

@@ -2,7 +2,7 @@ import inspect
 import json
 import logging
 from json import JSONDecodeError
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from openai import APIResponseValidationError, AsyncOpenAI
 
@@ -12,20 +12,20 @@ from .rule_manager import RuleManager
 
 
 class ModerationPlanner:
-    def __init__(self,
-                 config: Config,
-                 rule_manager: RuleManager):
+    def __init__(self, config: Config, rule_manager: RuleManager):
         self._logger = logging.getLogger(__name__)
 
         self.config = config
         self.rule_manager = rule_manager
 
-        self._client = AsyncOpenAI(api_key=self.config.openai.api_key.get_secret_value(),
-                                   timeout=self.config.openai.timeout,
-                                   max_retries=self.config.openai.max_retries)
+        self._client = AsyncOpenAI(
+            api_key=self.config.openai.api_key.get_secret_value(),
+            timeout=self.config.openai.timeout,
+            max_retries=self.config.openai.max_retries,
+        )
         self.moderation = self.config.moderation.enabled
 
-        self._functions: List[Dict[str]] = []
+        self._functions: List[Dict[str, Any]] = []
 
     def set_functions(self, *functions):
         if not functions:
@@ -34,41 +34,65 @@ class ModerationPlanner:
         self._functions.clear()
         for func in functions:
             sig = inspect.signature(func)
-            args = {name: str(param.annotation) if param.annotation != inspect._empty else "str"
-                    for name, param in sig.parameters.items()}
+            args = {
+                name: (
+                    str(param.annotation)
+                    if param.annotation != inspect._empty
+                    else "str"
+                )
+                for name, param in sig.parameters.items()
+            }
 
-            self._functions.append({
-                "name": func.__name__,
-                "args": args,
-                "description": func.__doc__ or ""
-            })
+            self._functions.append(
+                {"name": func.__name__, "args": args, "description": func.__doc__ or ""}
+            )
 
         function_names = self.get_function_names()
 
         if "moderation_decision" not in function_names:
-            self._logger.warning("Critical function 'moderation_decision' not found. Running the bot in this mode is NOT recommended!")
+            self._logger.warning(
+                "Critical function 'moderation_decision' not found. Running the bot in this mode is NOT recommended!"
+            )
+
+        self._logger.info(
+            f"Functions added: {', '.join(function_names)}. Total={len(self._functions)}"
+        )
 
-        self._logger.info(f"Functions added: {', '.join(function_names)}. Total={len(self._functions)}")
+    def get_function_names(self) -> List[str]:
+        return [name for f in self._functions if (name := f.get("name"))]
 
-    def get_function_names(self) -> Union[List[str], None]:
-        return [f.get("name") for f in self._functions]
+    async def plan(
+        self, text: Optional[str] = None, image: Optional[str] = None
+    ) -> List[Dict[str, Union[list, str]]]:
+        if not self.moderation:
+            return [
+                {
+                    "name": "moderation_decision",
+                    "args": ["APPROVE", "Модерация выключена."]
+                }
+            ]
 
-    async def plan(self, message_text: str) -> List[Dict[str, Union[list, str]]]:
         if "omni" in self.config.moderation.types:
-            moderation = await self._client.moderations.create(
-                model="omni-moderation-latest",
-                input=[
-                    {
-                        "type": "text",
-                        "text": message_text
-                    }
-                ]
-            )
+            content = []
+            if text:
+                content.append({"type": "text", "text": text})
+            if image:
+                content.append({"type": "image_url", "image_url": {"url": image}})
+
+            if content:
+                moderation = await self._client.moderations.create(
+                    model="omni-moderation-latest", input=content
+                )
 
-            if moderation.results[0].flagged:
-                return [{"name": "moderation_decision", "args": ["REJECT", "Сообщение заблокировано автомодератором"]}]
+                if moderation.results[0].flagged:
+                    return [
+                        {
+                            "name": "moderation_decision",
+                            "args": ["REJECT", "Сообщение заблокировано автомодератором."]
+                        }
+                    ]
 
-        if "gpt" in self.config.moderation.types:
+        if "gpt" in self.config.moderation.types and text:
             funcs = self._functions
             funcs_prompt = "\n".join(
                 f"- {func['name']}({', '.join(f'{arg}: {ann}' for arg, ann in (func.get('args') or {}).items())})"
@@ -86,24 +110,24 @@ class ModerationPlanner:
                         {
                             "role": "system",
                             "content": "Respond strictly with a JSON array in the following format:\n"
-                                    "`[{\"name\": ..., \"args\": [...]} , ...]`\n"
-                                    "`name` - the function name, `args` - an ordered list of arguments.\n"
-                                    "Output only a valid JSON. Choose functions based on the user's request and the function descriptions.\n"
-                                    "You are allowed to call multiple functions, listing them in order in the output.\n\n"
-                                    "**IMPORTANT:**\n"
-                                    "- Each function must include **all and only the required arguments** specified in its description.\n"
-                                    "- Do not invent additional arguments.\n"
-                                    "- Do not omit required arguments.\n"
-                                    "- `args` must be in the order specified in the function description.\n\n"
-                                    "Available functions:\n"
-                                    f"{funcs_prompt}"
+                            '`[{"name": ..., "args": [...]} , ...]`\n'
+                            "`name` - the function name, `args` - an ordered list of arguments.\n"
+                            "Output only a valid JSON. Choose functions based on the user's request and the function descriptions.\n"
+                            "You are allowed to call multiple functions, listing them in order in the output.\n\n"
+                            "**IMPORTANT:**\n"
+                            "- Each function must include **all and only the required arguments** specified in its description.\n"
+                            "- Do not invent additional arguments.\n"
+                            "- Do not omit required arguments.\n"
+                            "- `args` must be in the order specified in the function description.\n\n"
+                            "Available functions:\n"
+                            f"{funcs_prompt}",
                         },
-                        *[{"role": "system", "content": rule} for rule in self.rule_manager.get_rules()],
-                        {
-                            "role": "user",
-                            "content": message_text
-                        }
-                    ]
+                        *[
+                            {"role": "system", "content": rule}
+                            for rule in self.rule_manager.get_rules()
+                        ],
+                        {"role": "user", "content": text},
+                    ],
                 )
 
                 try:
@@ -120,4 +144,11 @@ class ModerationPlanner:
             if not result:
                 raise APIResponseValidationError()
 
-            return result
+            return result
+
+        return [
+            {
+                "name": "moderation_decision",
+                "args": ["APPROVE", "Модераторы не сработали."],
+            }
+        ]