Skip to content
Snippets Groups Projects
Commit f4e8f649 authored by Tim Repke's avatar Tim Repke
Browse files

migrate pydantic

parent b253a4a4
No related branches found
No related tags found
1 merge request!53OpenAlex Search and pydantic v2 migration
......@@ -4,7 +4,9 @@ import json
import toml
import os
from pydantic import BaseSettings, BaseModel, PostgresDsn, AnyHttpUrl, EmailStr, validator
from pydantic_settings import SettingsConfigDict, BaseSettings
from pydantic.networks import PostgresDsn
from pydantic import field_validator, FieldValidationInfo, AnyHttpUrl, BaseModel, EmailStr
# For more information how BaseSettings work, check the documentation:
......@@ -32,7 +34,8 @@ class ServerConfig(BaseModel):
HEADER_TRUSTED_HOST: bool = False # set to true to allow hosts from any origin
CORS_ORIGINS: list[AnyHttpUrl] = [] # list of trusted hosts
@validator("CORS_ORIGINS", pre=True)
@field_validator('CORS_ORIGINS', mode='before')
@classmethod
def assemble_cors_origins(cls, v: str | list[str]) -> str | list[str]:
if isinstance(v, str) and not v.startswith('['):
return [i.strip() for i in v.split(',')]
......@@ -46,6 +49,7 @@ class ServerConfig(BaseModel):
class DatabaseConfig(BaseModel):
SCHEME: str = 'postgresql'
HOST: str = 'localhost' # host of the db server
PORT: int = 5432 # port of the db server
USER: str = 'nacsos' # username for the database
......@@ -54,16 +58,20 @@ class DatabaseConfig(BaseModel):
CONNECTION_STR: PostgresDsn | None = None
@validator('CONNECTION_STR', pre=True)
def build_connection_string(cls, v: str | None, values: dict[str, Any]) -> Any:
@field_validator('CONNECTION_STR', mode='before')
def build_connection_string(cls, v: str | None, info: FieldValidationInfo) -> str:
assert info.config is not None
if isinstance(v, str):
return v
return PostgresDsn.build(
scheme="postgresql",
user=values.get('USER'),
password=values.get('PASSWORD'),
host=values.get('HOST'),
path=f'/{values.get("DATABASE", "")}',
scheme=info.data.get('SCHEME', 'postgresql'),
username=info.data.get('USER'),
password=info.data.get('PASSWORD'),
host=info.data.get('HOST'),
port=info.data.get('PORT'),
path=f'/{info.data.get("DATABASE", "")}',
)
......@@ -77,15 +85,17 @@ class EmailConfig(BaseModel):
SENDER_NAME: str | None = 'NACSOS'
ENABLED: bool = False
@validator("ENABLED", pre=True)
def get_emails_enabled(cls, v: bool, values: dict[str, Any]) -> bool:
@field_validator('ENABLED', mode='before')
@classmethod
def get_emails_enabled(cls, v: str | None, info: FieldValidationInfo) -> bool:
assert info.config is not None
return bool(
values.get('SMTP_HOST')
and values.get('SMTP_PORT')
and values.get('SENDER_ADDRESS')
info.data.get('SMTP_HOST')
and info.data.get('SMTP_PORT')
and info.data.get('SENDER_ADDRESS')
)
TEST_USER: EmailStr = EmailStr('test@nacsos.eu')
TEST_USER: EmailStr = 'test@nacsos.eu'
class UsersConfig(BaseModel):
......@@ -113,11 +123,14 @@ class Settings(BaseSettings):
LOG_CONF_FILE: str = 'config/logging.conf'
LOGGING_CONF: dict[str, Any] | None = None
@validator('LOGGING_CONF', pre=True)
def read_logging_config(cls, v: dict[str, Any] | None, values: dict[str, str]) -> dict[str, Any]:
@field_validator('LOGGING_CONF', mode='before')
@classmethod
def get_emails_enabled(cls, v: dict[str, Any] | None, info: FieldValidationInfo) -> dict[str, Any]:
assert info.config is not None
if isinstance(v, dict):
return v
filename = values.get('LOG_CONF_FILE', None)
filename = info.data.get('LOG_CONF_FILE', None)
if filename is not None:
with open(filename, 'r') as f:
ret = toml.loads(f.read())
......@@ -125,10 +138,7 @@ class Settings(BaseSettings):
return ret
raise ValueError('Logging config invalid!')
class Config:
case_sensitive = True
env_prefix = 'NACSOS_'
env_nested_delimiter = '__'
model_config = SettingsConfigDict(case_sensitive=True, env_prefix='NACSOS_', env_nested_delimiter='__')
conf_file = os.environ.get('NACSOS_CONFIG', 'config/default.env')
......
......@@ -13,9 +13,9 @@ if TYPE_CHECKING:
else:
AnyEvent = Union[events.BaseEvent.get_subclasses()]
AnyEventType = Literal[tuple(sc.__name__ for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
AnyEventLiteral = Literal[tuple(sc._name for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
AnyEventLiteral = Literal[tuple(sc.name for sc in events.BaseEvent.get_subclasses())] # noqa PyProtectedMember
# Permanent/global listeners
eventbus.on(events.ExampleEvent._name, example.test_listener) # noqa PyProtectedMember
eventbus.on(events.ExampleEvent.name, example.test_listener) # noqa PyProtectedMember
__all__ = ['eventbus', 'events', 'AnyEvent', 'AnyEventType']
......@@ -4,7 +4,7 @@ from pydantic import BaseModel
class BaseEvent(BaseModel):
_name = ClassVar[str]
name: ClassVar[str]
@classmethod
def get_subclasses(cls):
......@@ -18,9 +18,9 @@ class BaseEvent(BaseModel):
class ExampleEvent(BaseEvent):
_name = 'Example_*'
name = 'Example_*'
payload_a: str
class ExampleSubEvent(ExampleEvent):
_name = 'Example_sub'
name = 'Example_sub'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment