Sfoglia il codice sorgente

refactor(repositories): remove session.begin transaction management

Librellium 3 settimane fa
parent
commit
093877688f

+ 16 - 18
anonflow/database/repositories/ban.py

@@ -6,12 +6,11 @@ from anonflow.database.orm import Ban
 
 class BanRepository:
     async def ban(self, session: AsyncSession, actor_user_id: int, user_id: int):
-        async with session.begin():
-            ban = Ban(
-                user_id=user_id,
-                banned_by=actor_user_id
-            )
-            session.add(ban)
+        ban = Ban(
+            user_id=user_id,
+            banned_by=actor_user_id
+        )
+        session.add(ban)
 
     async def is_banned(self, session: AsyncSession, user_id: int):
         result = await session.execute(
@@ -25,16 +24,15 @@ class BanRepository:
         return bool(result.scalar_one_or_none())
 
     async def unban(self, session: AsyncSession, actor_user_id: int, user_id: int):
-        async with session.begin():
-            await session.execute(
-                update(Ban)
-                .where(
-                    Ban.user_id == user_id,
-                    Ban.is_active.is_(True)
-                )
-                .values(
-                    is_active=False,
-                    unbanned_at=func.now(),
-                    unbanned_by=actor_user_id
-                )
+        await session.execute(
+            update(Ban)
+            .where(
+                Ban.user_id == user_id,
+                Ban.is_active.is_(True)
+            )
+            .values(
+                is_active=False,
+                unbanned_at=func.now(),
+                unbanned_by=actor_user_id
             )
+        )

+ 11 - 14
anonflow/database/repositories/base.py

@@ -13,9 +13,8 @@ class BaseRepository:
         )
 
     async def _add(self, session: AsyncSession, model_args: Dict[str, Any]):
-        async with session.begin():
-            obj = self.model(**model_args)
-            session.add(obj)
+        obj = self.model(**model_args)
+        session.add(obj)
 
     async def _get(self, session: AsyncSession, filters: Dict[str, Any], options: List[Any] = []):
         result = await session.execute(
@@ -35,19 +34,17 @@ class BaseRepository:
         return bool(result.scalar_one_or_none())
 
     async def _remove(self, session: AsyncSession, filters: Dict[str, Any]):
-        async with session.begin():
-            await session.execute(
-                delete(self.model)
-                .filter_by(**filters)
-            )
+        await session.execute(
+            delete(self.model)
+            .filter_by(**filters)
+        )
 
     async def _update(self, session: AsyncSession, filters: Dict[str, Any], fields: Dict[str, Any]):
         if not fields:
             return
 
-        async with session.begin():
-            await session.execute(
-                update(self.model)
-                .filter_by(**filters)
-                .values(**fields)
-            )
+        await session.execute(
+            update(self.model)
+            .filter_by(**filters)
+            .values(**fields)
+        )

+ 2 - 47
anonflow/database/repositories/moderator.py

@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import Optional
 
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -9,12 +8,6 @@ from anonflow.database.orm import Moderator
 from .base import BaseRepository
 
 
-@dataclass
-class ModeratorPermissions:
-    can_approve: bool
-    can_ban: bool
-    can_manage_moderators: bool
-
 class ModeratorRepository(BaseRepository):
     model = Moderator
 
@@ -22,18 +15,13 @@ class ModeratorRepository(BaseRepository):
         self,
         session: AsyncSession,
         user_id: int,
-        *,
-        can_approve_posts: bool = True,
-        can_manage_bans: bool = False,
-        can_manage_moderators: bool = False
+        **fields
     ):
         await super()._add(
             session,
             model_args={
                 "user_id": user_id,
-                "can_approve_posts": can_approve_posts,
-                "can_manage_bans": can_manage_bans,
-                "can_manage_moderators": can_manage_moderators
+                **fields
             }
         )
 
@@ -46,15 +34,6 @@ class ModeratorRepository(BaseRepository):
             ]
         )
 
-    async def get_permissions(self, session: AsyncSession, user_id: int):
-        result = await self.get(session, user_id)
-        if result:
-            return ModeratorPermissions(
-                result.can_approve_posts.value,
-                result.can_manage_bans.value,
-                result.can_manage_moderators.value
-            )
-
     async def has(self, session: AsyncSession, user_id: int):
         return await super()._has(
             session,
@@ -73,27 +52,3 @@ class ModeratorRepository(BaseRepository):
             filters={"user_id": user_id},
             fields=fields
         )
-
-    async def update_permissions(
-        self,
-        session: AsyncSession,
-        user_id: int,
-        *,
-        can_approve_posts: Optional[bool] = None,
-        can_manage_bans: Optional[bool] = None,
-        can_manage_moderators: Optional[bool] = None
-    ):
-        to_update = {}
-        for key, value in (
-            ("can_approve_posts", can_approve_posts),
-            ("can_manage_bans", can_manage_bans),
-            ("can_manage_moderators", can_manage_moderators),
-        ):
-            if value is not None:
-                to_update[key] = value
-
-        await self.update(
-            session,
-            user_id,
-            **to_update
-        )