ソースを参照

Merge pull request #23 from librellium/feature/services-system-actor

Feature/services system actor
Librellium 3 週間 前
コミット
f1b0c40af6

+ 1 - 0
anonflow/constants.py

@@ -0,0 +1 @@
+SYSTEM_USER_ID = -1

+ 9 - 0
anonflow/database/database.py

@@ -1,3 +1,6 @@
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator
+
 from sqlalchemy.engine import URL
 from sqlalchemy.engine import URL
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import sessionmaker
@@ -14,6 +17,12 @@ class Database:
             self._engine, expire_on_commit=False, class_=AsyncSession # type: ignore
             self._engine, expire_on_commit=False, class_=AsyncSession # type: ignore
         )
         )
 
 
+    @asynccontextmanager
+    async def begin_session(self) -> AsyncGenerator[AsyncSession, None]:
+        async with self._session_maker() as session: # type: ignore
+            async with session.begin():
+                yield session
+
     async def close(self):
     async def close(self):
         await self._engine.dispose()
         await self._engine.dispose()
 
 

+ 2 - 0
anonflow/database/orm.py

@@ -48,6 +48,8 @@ class Moderator(Base):
     can_manage_bans = Column(Boolean, nullable=False, default=False)
     can_manage_bans = Column(Boolean, nullable=False, default=False)
     can_manage_moderators = Column(Boolean, nullable=False, default=False)
     can_manage_moderators = Column(Boolean, nullable=False, default=False)
 
 
+    is_root = Column(Boolean, nullable=False, default=False)
+
     user = relationship("User", back_populates="moderator")
     user = relationship("User", back_populates="moderator")
 
 
 class User(Base):
 class User(Base):

+ 16 - 18
anonflow/database/repositories/ban.py

@@ -6,12 +6,11 @@ from anonflow.database.orm import Ban
 
 
 class BanRepository:
 class BanRepository:
     async def ban(self, session: AsyncSession, actor_user_id: int, user_id: int):
     async def ban(self, session: AsyncSession, actor_user_id: int, user_id: int):
-        async with session.begin():
-            ban = Ban(
-                user_id=user_id,
-                banned_by=actor_user_id
-            )
-            session.add(ban)
+        ban = Ban(
+            user_id=user_id,
+            banned_by=actor_user_id
+        )
+        session.add(ban)
 
 
     async def is_banned(self, session: AsyncSession, user_id: int):
     async def is_banned(self, session: AsyncSession, user_id: int):
         result = await session.execute(
         result = await session.execute(
@@ -25,16 +24,15 @@ class BanRepository:
         return bool(result.scalar_one_or_none())
         return bool(result.scalar_one_or_none())
 
 
     async def unban(self, session: AsyncSession, actor_user_id: int, user_id: int):
     async def unban(self, session: AsyncSession, actor_user_id: int, user_id: int):
-        async with session.begin():
-            await session.execute(
-                update(Ban)
-                .where(
-                    Ban.user_id == user_id,
-                    Ban.is_active.is_(True)
-                )
-                .values(
-                    is_active=False,
-                    unbanned_at=func.now(),
-                    unbanned_by=actor_user_id
-                )
+        await session.execute(
+            update(Ban)
+            .where(
+                Ban.user_id == user_id,
+                Ban.is_active.is_(True)
+            )
+            .values(
+                is_active=False,
+                unbanned_at=func.now(),
+                unbanned_by=actor_user_id
             )
             )
+        )

+ 11 - 14
anonflow/database/repositories/base.py

@@ -13,9 +13,8 @@ class BaseRepository:
         )
         )
 
 
     async def _add(self, session: AsyncSession, model_args: Dict[str, Any]):
     async def _add(self, session: AsyncSession, model_args: Dict[str, Any]):
-        async with session.begin():
-            obj = self.model(**model_args)
-            session.add(obj)
+        obj = self.model(**model_args)
+        session.add(obj)
 
 
     async def _get(self, session: AsyncSession, filters: Dict[str, Any], options: List[Any] = []):
     async def _get(self, session: AsyncSession, filters: Dict[str, Any], options: List[Any] = []):
         result = await session.execute(
         result = await session.execute(
@@ -35,19 +34,17 @@ class BaseRepository:
         return bool(result.scalar_one_or_none())
         return bool(result.scalar_one_or_none())
 
 
     async def _remove(self, session: AsyncSession, filters: Dict[str, Any]):
     async def _remove(self, session: AsyncSession, filters: Dict[str, Any]):
-        async with session.begin():
-            await session.execute(
-                delete(self.model)
-                .filter_by(**filters)
-            )
+        await session.execute(
+            delete(self.model)
+            .filter_by(**filters)
+        )
 
 
     async def _update(self, session: AsyncSession, filters: Dict[str, Any], fields: Dict[str, Any]):
     async def _update(self, session: AsyncSession, filters: Dict[str, Any], fields: Dict[str, Any]):
         if not fields:
         if not fields:
             return
             return
 
 
-        async with session.begin():
-            await session.execute(
-                update(self.model)
-                .filter_by(**filters)
-                .values(**fields)
-            )
+        await session.execute(
+            update(self.model)
+            .filter_by(**filters)
+            .values(**fields)
+        )

+ 2 - 47
anonflow/database/repositories/moderator.py

@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import Optional
 from typing import Optional
 
 
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -9,12 +8,6 @@ from anonflow.database.orm import Moderator
 from .base import BaseRepository
 from .base import BaseRepository
 
 
 
 
-@dataclass
-class ModeratorPermissions:
-    can_approve: bool
-    can_ban: bool
-    can_manage_moderators: bool
-
 class ModeratorRepository(BaseRepository):
 class ModeratorRepository(BaseRepository):
     model = Moderator
     model = Moderator
 
 
@@ -22,18 +15,13 @@ class ModeratorRepository(BaseRepository):
         self,
         self,
         session: AsyncSession,
         session: AsyncSession,
         user_id: int,
         user_id: int,
-        *,
-        can_approve_posts: bool = True,
-        can_manage_bans: bool = False,
-        can_manage_moderators: bool = False
+        **fields
     ):
     ):
         await super()._add(
         await super()._add(
             session,
             session,
             model_args={
             model_args={
                 "user_id": user_id,
                 "user_id": user_id,
-                "can_approve_posts": can_approve_posts,
-                "can_manage_bans": can_manage_bans,
-                "can_manage_moderators": can_manage_moderators
+                **fields
             }
             }
         )
         )
 
 
@@ -46,15 +34,6 @@ class ModeratorRepository(BaseRepository):
             ]
             ]
         )
         )
 
 
-    async def get_permissions(self, session: AsyncSession, user_id: int):
-        result = await self.get(session, user_id)
-        if result:
-            return ModeratorPermissions(
-                result.can_approve_posts.value,
-                result.can_manage_bans.value,
-                result.can_manage_moderators.value
-            )
-
     async def has(self, session: AsyncSession, user_id: int):
     async def has(self, session: AsyncSession, user_id: int):
         return await super()._has(
         return await super()._has(
             session,
             session,
@@ -73,27 +52,3 @@ class ModeratorRepository(BaseRepository):
             filters={"user_id": user_id},
             filters={"user_id": user_id},
             fields=fields
             fields=fields
         )
         )
-
-    async def update_permissions(
-        self,
-        session: AsyncSession,
-        user_id: int,
-        *,
-        can_approve_posts: Optional[bool] = None,
-        can_manage_bans: Optional[bool] = None,
-        can_manage_moderators: Optional[bool] = None
-    ):
-        to_update = {}
-        for key, value in (
-            ("can_approve_posts", can_approve_posts),
-            ("can_manage_bans", can_manage_bans),
-            ("can_manage_moderators", can_manage_moderators),
-        ):
-            if value is not None:
-                to_update[key] = value
-
-        await self.update(
-            session,
-            user_id,
-            **to_update
-        )

+ 2 - 4
anonflow/services/__init__.py

@@ -1,13 +1,11 @@
-from .accounts.exceptions import ForbiddenError
 from .accounts.moderator import ModeratorService
 from .accounts.moderator import ModeratorService
 from .accounts.user import UserService
 from .accounts.user import UserService
 from .transport.delivery import DeliveryService
 from .transport.delivery import DeliveryService
 from .transport.router import MessageRouter
 from .transport.router import MessageRouter
 
 
 __all__ = [
 __all__ = [
+    "ModeratorService",
+    "UserService",
     "DeliveryService",
     "DeliveryService",
     "MessageRouter",
     "MessageRouter",
-    "ForbiddenError",
-    "ModeratorService",
-    "UserService"
 ]
 ]

+ 7 - 1
anonflow/services/accounts/__init__.py

@@ -1,4 +1,10 @@
 from .moderator import ModeratorService
 from .moderator import ModeratorService
+from .moderator.exceptions import ModeratorPermissionError, SelfActionError
 from .user import UserService
 from .user import UserService
 
 
-__all__ = ["ModeratorService", "UserService"]
+__all__ = [
+    "ModeratorService",
+    "ModeratorPermissionError",
+    "SelfActionError",
+    "UserService"
+]

+ 0 - 1
anonflow/services/accounts/exceptions.py

@@ -1 +0,0 @@
-class ForbiddenError(PermissionError): ...

+ 0 - 132
anonflow/services/accounts/moderator.py

@@ -1,132 +0,0 @@
-import logging
-from typing import Optional
-
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from anonflow.database import BanRepository, Database, ModeratorRepository
-
-from .exceptions import ForbiddenError
-
-
-class ModeratorService:
-    def __init__(
-        self,
-        database: Database,
-        ban_repository: BanRepository,
-        moderator_repository: ModeratorRepository
-    ):
-        self._logger = logging.getLogger(__name__)
-
-        self._database = database
-        self._ban_repository = ban_repository
-        self._moderator_repository = moderator_repository
-
-    async def add(self, actor_user_id: int, user_id: int):
-        try:
-            async with self._database.get_session() as session:
-                if await self._can_manage_moderators(session, actor_user_id):
-                    await self._moderator_repository.add(session, user_id)
-                else:
-                    raise ForbiddenError()
-        except IntegrityError:
-            self._logger.warning("Failed to add moderator user_id=%s", user_id)
-
-    async def ban(self, actor_user_id: int, user_id: int):
-        async with self._database.get_session() as session:
-            if await self._can_manage_bans(session, actor_user_id):
-                await self._ban_repository.ban(session, actor_user_id, user_id)
-            else:
-                raise ForbiddenError()
-
-    async def _get_permission(self, session: AsyncSession, actor_user_id: int, name: str) -> bool:
-        moderator = await self._moderator_repository.get(session, actor_user_id)
-        return getattr(getattr(moderator, name, None), "value", False)
-
-    async def _can_approve_posts(self, session: AsyncSession, actor_user_id: int):
-        return await self._get_permission(session, actor_user_id, "can_approve_posts")
-
-    async def can_approve_posts(self, actor_user_id: int):
-        async with self._database.get_session() as session:
-            return await self._can_approve_posts(session, actor_user_id)
-
-    async def _can_manage_bans(self, session: AsyncSession, actor_user_id: int):
-        return await self._get_permission(session, actor_user_id, "can_manage_bans")
-
-    async def can_manage_bans(self, actor_user_id: int):
-        async with self._database.get_session() as session:
-            return await self._can_manage_bans(session, actor_user_id)
-
-    async def _can_manage_moderators(self, session: AsyncSession, actor_user_id: int):
-        return await self._get_permission(session, actor_user_id, "can_manage_moderators")
-
-    async def can_manage_moderators(self, actor_user_id: int):
-        async with self._database.get_session() as session:
-            return await self._can_manage_moderators(session, actor_user_id)
-
-    async def get(self, user_id: int):
-        async with self._database.get_session() as session:
-            return await self._moderator_repository.get(session, user_id)
-
-    async def get_permissions(self, user_id: int):
-        async with self._database.get_session() as session:
-            return await self._moderator_repository.get_permissions(session, user_id)
-
-    async def has(self, user_id: int):
-        async with self._database.get_session() as session:
-            return await self._moderator_repository.has(session, user_id)
-
-    async def is_banned(self, user_id: int):
-        async with self._database.get_session() as session:
-            return await self._ban_repository.is_banned(session, user_id)
-
-    async def remove(self, actor_user_id: int, user_id: int):
-        try:
-            async with self._database.get_session() as session:
-                if await self._can_manage_moderators(session, actor_user_id):
-                    await self._moderator_repository.remove(session, user_id)
-                else:
-                    raise ForbiddenError()
-        except IntegrityError:
-            self._logger.warning("Failed to remove moderator user_id=%s", user_id)
-
-    async def unban(self, actor_user_id: int, user_id: int):
-        async with self._database.get_session() as session:
-            if await self._can_manage_bans(session, actor_user_id):
-                await self._ban_repository.unban(session, actor_user_id, user_id)
-            else:
-                raise ForbiddenError()
-
-    async def update(self, actor_user_id: int, user_id: int, **fields):
-        try:
-            async with self._database.get_session() as session:
-                if await self._can_manage_moderators(session, actor_user_id):
-                    await self._moderator_repository.update(session, user_id, **fields)
-                else:
-                    raise ForbiddenError()
-        except IntegrityError:
-            self._logger.warning("Failed to update moderator user_id=%s", user_id)
-
-    async def update_permissions(
-        self,
-        actor_user_id: int,
-        user_id: int,
-        *,
-        can_approve_posts: Optional[bool] = None,
-        can_manage_bans: Optional[bool] = None,
-        can_manage_moderators: Optional[bool] = None
-    ):
-        try:
-            async with self._database.get_session() as session:
-                if await self._can_manage_moderators(session, actor_user_id):
-                    await self._moderator_repository.update_permissions(
-                        session,
-                        user_id,
-                        can_approve_posts=can_approve_posts,
-                        can_manage_bans=can_manage_bans,
-                        can_manage_moderators=can_manage_moderators
-                    )
-                else:
-                    raise ForbiddenError()
-        except IntegrityError:
-            self._logger.warning("Failed to update moderator user_id=%s", user_id)

+ 3 - 0
anonflow/services/accounts/moderator/__init__.py

@@ -0,0 +1,3 @@
+from .service import ModeratorService
+
+__all__ = ["ModeratorService"]

+ 3 - 0
anonflow/services/accounts/moderator/exceptions.py

@@ -0,0 +1,3 @@
+class ModeratorPermissionError(PermissionError): ...
+
+class SelfActionError(ModeratorPermissionError): ...

+ 17 - 0
anonflow/services/accounts/moderator/permissions.py

@@ -0,0 +1,17 @@
+from dataclasses import dataclass, asdict
+from enum import Enum
+
+
+@dataclass(frozen=True)
+class ModeratorPermissions:
+    can_approve_posts: bool = False
+    can_manage_bans: bool = False
+    can_manage_moderators: bool = False
+
+    def to_dict(self):
+        return asdict(self)
+
+class ModeratorPermission(str, Enum):
+    APPROVE_POSTS = "can_approve_posts"
+    MANAGE_BANS = "can_manage_bans"
+    MANAGE_MODERATORS = "can_manage_moderators"

+ 156 - 0
anonflow/services/accounts/moderator/service.py

@@ -0,0 +1,156 @@
+import logging
+
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from anonflow.constants import SYSTEM_USER_ID
+from anonflow.database import BanRepository, Database, ModeratorRepository
+
+from .exceptions import ModeratorPermissionError, SelfActionError
+from .permissions import ModeratorPermission, ModeratorPermissions
+
+
+class ModeratorService:
+    def __init__(
+        self,
+        database: Database,
+        ban_repository: BanRepository,
+        moderator_repository: ModeratorRepository
+    ):
+        self._logger = logging.getLogger(__name__)
+
+        self._database = database
+        self._ban_repository = ban_repository
+        self._moderator_repository = moderator_repository
+
+    @staticmethod
+    def _assert_not_self(actor_user_id: int, user_id: int):
+        if actor_user_id == user_id:
+            raise SelfActionError(
+                f"Moderator user_id={actor_user_id} cannot perform this action on themselves (target user_id={user_id})."
+            )
+
+    async def add(self, actor_user_id: int, user_id: int):
+        try:
+            async with self._database.begin_session() as session:
+                if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_MODERATORS):
+                    self._assert_not_self(actor_user_id, user_id)
+                    await self._moderator_repository.add(session, user_id)
+                else:
+                    raise ModeratorPermissionError(
+                        f"Moderator user_id={actor_user_id} does not have permission to perform 'add'."
+                    )
+        except IntegrityError:
+            self._logger.warning("Failed to add moderator user_id=%s", user_id)
+
+    async def ban(self, actor_user_id: int, user_id: int):
+        async with self._database.begin_session() as session:
+            if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_BANS):
+                self._assert_not_self(actor_user_id, user_id)
+                await self._ban_repository.ban(session, actor_user_id, user_id)
+            else:
+                raise ModeratorPermissionError(
+                    f"Moderator user_id={actor_user_id} does not have permission to perform 'ban'."
+                )
+
+    async def _can(self, session: AsyncSession, actor_user_id: int, permission: ModeratorPermission) -> bool:
+        moderator = await self._moderator_repository.get(session, actor_user_id)
+        if moderator:
+            if moderator.is_root.value:
+                return True
+            return getattr(getattr(moderator, permission, None), "value", False)
+
+        return False
+
+    async def can(self, actor_user_id: int, permission: ModeratorPermission):
+        async with self._database.get_session() as session:
+            return self._can(session, actor_user_id, permission)
+
+    async def get(self, user_id: int):
+        async with self._database.get_session() as session:
+            return await self._moderator_repository.get(session, user_id)
+
+    async def get_permissions(self, user_id: int):
+        async with self._database.get_session() as session:
+            result = await self._moderator_repository.get(session, user_id)
+            if not result:
+                return ModeratorPermissions()
+
+            return ModeratorPermissions(
+                **{
+                    key: value
+                    for key, value in result.__dict__.items()
+                    if key.startswith("can_")
+                }
+            )
+
+    async def has(self, user_id: int):
+        async with self._database.get_session() as session:
+            return await self._moderator_repository.has(session, user_id)
+
+    async def init(self):
+        async with self._database.begin_session() as session:
+            if not await self._moderator_repository.has(session, SYSTEM_USER_ID):
+                await self._moderator_repository.add(session, SYSTEM_USER_ID, is_root=True)
+
+    async def is_banned(self, user_id: int):
+        async with self._database.get_session() as session:
+            return await self._ban_repository.is_banned(session, user_id)
+
+    async def remove(self, actor_user_id: int, user_id: int):
+        try:
+            async with self._database.begin_session() as session:
+                if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_MODERATORS):
+                    self._assert_not_self(actor_user_id, user_id)
+                    await self._moderator_repository.remove(session, user_id)
+                else:
+                    raise ModeratorPermissionError(
+                        f"Moderator user_id={actor_user_id} does not have permission to perform 'remove'."
+                    )
+        except IntegrityError:
+            self._logger.warning("Failed to remove moderator user_id=%s", user_id)
+
+    async def unban(self, actor_user_id: int, user_id: int):
+        async with self._database.begin_session() as session:
+            if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_BANS):
+                self._assert_not_self(actor_user_id, user_id)
+                await self._ban_repository.unban(session, actor_user_id, user_id)
+            else:
+                raise ModeratorPermissionError(
+                    f"Moderator user_id={actor_user_id} does not have permission to perform 'unban'."
+                )
+
+    async def update(self, actor_user_id: int, user_id: int, **fields):
+        try:
+            async with self._database.begin_session() as session:
+                if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_MODERATORS):
+                    self._assert_not_self(actor_user_id, user_id)
+                    await self._moderator_repository.update(session, user_id, **fields)
+                else:
+                    raise ModeratorPermissionError(
+                        f"Moderator user_id={actor_user_id} does not have permission to perform 'update'."
+                    )
+        except IntegrityError:
+            self._logger.warning("Failed to update moderator user_id=%s", user_id)
+
+    async def update_permissions(
+        self,
+        actor_user_id: int,
+        user_id: int,
+        permissions: ModeratorPermissions
+    ):
+        try:
+            async with self._database.begin_session() as session:
+                if await self._can(session, actor_user_id, ModeratorPermission.MANAGE_MODERATORS):
+                    self._assert_not_self(actor_user_id, user_id)
+                    await self._moderator_repository.update(
+                        session,
+                        user_id,
+                        **permissions.to_dict()
+                    )
+                else:
+                    raise ModeratorPermissionError(
+                        f"Moderator user_id={actor_user_id} does not have permission to perform 'update_permissions'."
+                    )
+        except IntegrityError:
+            self._logger.warning("Failed to update moderator user_id=%s", user_id)

+ 3 - 3
anonflow/services/accounts/user.py

@@ -14,7 +14,7 @@ class UserService:
 
 
     async def add(self, user_id: int):
     async def add(self, user_id: int):
         try:
         try:
-            async with self._database.get_session() as session:
+            async with self._database.begin_session() as session:
                 await self._user_repository.add(session, user_id)
                 await self._user_repository.add(session, user_id)
         except IntegrityError:
         except IntegrityError:
             self._logger.warning("Failed to add user user_id=%s", user_id)
             self._logger.warning("Failed to add user user_id=%s", user_id)
@@ -29,14 +29,14 @@ class UserService:
 
 
     async def remove(self, user_id: int):
     async def remove(self, user_id: int):
         try:
         try:
-            async with self._database.get_session() as session:
+            async with self._database.begin_session() as session:
                 await self._user_repository.remove(session, user_id)
                 await self._user_repository.remove(session, user_id)
         except IntegrityError:
         except IntegrityError:
             self._logger.warning("Failed to remove user user_id=%s", user_id)
             self._logger.warning("Failed to remove user user_id=%s", user_id)
 
 
     async def update(self, user_id: int, **fields):
     async def update(self, user_id: int, **fields):
         try:
         try:
-            async with self._database.get_session() as session:
+            async with self._database.begin_session() as session:
                 await self._user_repository.update(session, user_id, **fields)
                 await self._user_repository.update(session, user_id, **fields)
         except IntegrityError:
         except IntegrityError:
             self._logger.warning("Failed to update user user_id=%s", user_id)
             self._logger.warning("Failed to update user user_id=%s", user_id)