config.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from pathlib import Path
  2. from string import Template
  3. import yaml
  4. from dotenv import dotenv_values
  5. from pydantic import BaseModel, Field, SecretStr
  6. from sqlalchemy.engine import URL
  7. from .models import App, Bot, Database, Logging, Moderation, OpenAI
  8. class Config(BaseModel):
  9. app: App = Field(default_factory=App)
  10. bot: Bot = Field(default_factory=Bot)
  11. database: Database = Field(default_factory=Database)
  12. openai: OpenAI = Field(default_factory=OpenAI)
  13. moderation: Moderation = Field(default_factory=Moderation)
  14. logging: Logging = Field(default_factory=Logging)
  15. def get_database_url(self):
  16. password = None
  17. if self.database.password:
  18. password = (
  19. self.database.password.get_secret_value()
  20. if isinstance(self.database.password, SecretStr)
  21. else self.database.password
  22. )
  23. return URL.create(
  24. drivername=self.database.backend,
  25. username=self.database.username,
  26. password=password,
  27. host=self.database.host,
  28. port=self.database.port,
  29. database=str(self.database.name_or_path),
  30. )
  31. def get_migrations_url(self):
  32. password = None
  33. if self.database.password:
  34. password = (
  35. self.database.password.get_secret_value()
  36. if isinstance(self.database.password, SecretStr)
  37. else self.database.password
  38. )
  39. return URL.create(
  40. drivername=self.database.migrations.backend,
  41. username=self.database.username,
  42. password=password,
  43. host=self.database.host,
  44. port=self.database.port,
  45. database=str(self.database.name_or_path),
  46. )
  47. @classmethod
  48. def load(cls, filepath: Path):
  49. filepath = Path(filepath)
  50. filepath.parent.mkdir(parents=True, exist_ok=True)
  51. if filepath.exists():
  52. with filepath.open(encoding="utf-8") as f:
  53. template = Template(f.read())
  54. rendered = template.safe_substitute(dotenv_values())
  55. data = yaml.safe_load(rendered) or {}
  56. return cls(**data) # type: ignore
  57. return cls()
  58. @classmethod
  59. def _prepare_for_save(cls, obj):
  60. if isinstance(obj, SecretStr):
  61. return obj.get_secret_value()
  62. elif isinstance(obj, dict):
  63. return {key: cls._prepare_for_save(value) for key, value in obj.items()}
  64. elif isinstance(obj, list):
  65. return [cls._prepare_for_save(value) for value in obj]
  66. return obj
  67. def save(self, filepath: Path):
  68. filepath.parent.mkdir(parents=True, exist_ok=True)
  69. with filepath.open("w", encoding="utf-8") as config_file:
  70. yaml.dump(
  71. self._prepare_for_save(self.model_dump()),
  72. config_file,
  73. width=float("inf"),
  74. sort_keys=False,
  75. default_flow_style=False,
  76. )