Write union types as X | Y where possible (#19111)

aka PEP 604, added in Python 3.10
Andrew Ferrazzutti, 2025-11-06 15:02:33 -05:00 (committed by GitHub)
parent 6790312831
commit fcac7e0282
465 changed files with 4034 additions and 4555 deletions
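
For readers less familiar with the syntax, here is a minimal sketch of the rewrite applied throughout this commit (the `lookup` function below is made up for illustration, not code from the diff): on Python 3.10+, PEP 604 lets `Optional[X]` be spelled `X | None` and `Union[X, Y]` be spelled `X | Y`.

```python
from collections.abc import Mapping

# Old spelling (pre-PEP 604), which this commit removes wherever possible:
#
#     from typing import Optional, Union
#
#     def lookup(table: Mapping[str, int], key: str, default: Optional[int] = None) -> Union[int, None]:
#         return table.get(key, default)


# New spelling, valid as a runtime annotation on Python 3.10 and later:
def lookup(table: Mapping[str, int], key: str, default: int | None = None) -> int | None:
    return table.get(key, default)


print(lookup({"a": 1}, "b"))  # prints: None
```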

@@ -25,7 +25,6 @@
 import argparse
 import os
 import subprocess
-from typing import Optional
 from zipfile import ZipFile
 from packaging.tags import Tag
@@ -80,7 +79,7 @@ def cpython(wheel_file: str, name: str, version: Version, tag: Tag) -> str:
     return new_wheel_file
-def main(wheel_file: str, dest_dir: str, archs: Optional[str]) -> None:
+def main(wheel_file: str, dest_dir: str, archs: str | None) -> None:
     """Entry point"""
     # Parse the wheel file name into its parts. Note that `parse_wheel_filename`

changelog.d/19111.misc (new file)
@@ -0,0 +1 @@
+Write union types as `X | Y` where possible, as per PEP 604, added in Python 3.10.

@@ -33,7 +33,6 @@ import sys
 import time
 import urllib
 from http import TwistedHttpClient
-from typing import Optional
 import urlparse
 from signedjson.key import NACL_ED25519, decode_verify_key_bytes
@@ -726,7 +725,7 @@ class SynapseCmd(cmd.Cmd):
         method,
         path,
         data=None,
-        query_params: Optional[dict] = None,
+        query_params: dict | None = None,
         alt_text=None,
     ):
         """Runs an HTTP request and pretty prints the output.

@@ -22,7 +22,6 @@
 import json
 import urllib
 from pprint import pformat
-from typing import Optional
 from twisted.internet import defer, reactor
 from twisted.web.client import Agent, readBody
@@ -90,7 +89,7 @@ class TwistedHttpClient(HttpClient):
         body = yield readBody(response)
         return json.loads(body)
-    def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None):
+    def _create_put_request(self, url, json_data, headers_dict: dict | None = None):
         """Wrapper of _create_request to issue a PUT request"""
         headers_dict = headers_dict or {}
@@ -101,7 +100,7 @@ class TwistedHttpClient(HttpClient):
             "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict
         )
-    def _create_get_request(self, url, headers_dict: Optional[dict] = None):
+    def _create_get_request(self, url, headers_dict: dict | None = None):
         """Wrapper of _create_request to issue a GET request"""
         return self._create_request("GET", url, headers_dict=headers_dict or {})
@@ -113,7 +112,7 @@ class TwistedHttpClient(HttpClient):
         data=None,
         qparams=None,
         jsonreq=True,
-        headers: Optional[dict] = None,
+        headers: dict | None = None,
     ):
         headers = headers or {}
@@ -138,7 +137,7 @@ class TwistedHttpClient(HttpClient):
     @defer.inlineCallbacks
     def _create_request(
-        self, method, url, producer=None, headers_dict: Optional[dict] = None
+        self, method, url, producer=None, headers_dict: dict | None = None
     ):
         """Creates and sends a request to the given url"""
         headers_dict = headers_dict or {}

@@ -68,7 +68,6 @@ from typing import (
     Mapping,
     MutableMapping,
     NoReturn,
-    Optional,
     SupportsIndex,
 )
@@ -468,7 +467,7 @@ def add_worker_roles_to_shared_config(
 def merge_worker_template_configs(
-    existing_dict: Optional[dict[str, Any]],
+    existing_dict: dict[str, Any] | None,
     to_be_merged_dict: dict[str, Any],
 ) -> dict[str, Any]:
     """When given an existing dict of worker template configuration consisting with both
@@ -1026,7 +1025,7 @@ def generate_worker_log_config(
     Returns: the path to the generated file
     """
     # Check whether we should write worker logs to disk, in addition to the console
-    extra_log_template_args: dict[str, Optional[str]] = {}
+    extra_log_template_args: dict[str, str | None] = {}
     if environ.get("SYNAPSE_WORKERS_WRITE_LOGS_TO_DISK"):
         extra_log_template_args["LOG_FILE_PATH"] = f"{data_dir}/logs/{worker_name}.log"

@@ -6,7 +6,7 @@ import os
 import platform
 import subprocess
 import sys
-from typing import Any, Mapping, MutableMapping, NoReturn, Optional
+from typing import Any, Mapping, MutableMapping, NoReturn
 import jinja2
@@ -50,7 +50,7 @@ def generate_config_from_template(
     config_dir: str,
     config_path: str,
     os_environ: Mapping[str, str],
-    ownership: Optional[str],
+    ownership: str | None,
 ) -> None:
     """Generate a homeserver.yaml from environment variables
@@ -147,7 +147,7 @@ def generate_config_from_template(
     subprocess.run(args, check=True)
-def run_generate_config(environ: Mapping[str, str], ownership: Optional[str]) -> None:
+def run_generate_config(environ: Mapping[str, str], ownership: str | None) -> None:
     """Run synapse with a --generate-config param to generate a template config file
     Args:

@@ -299,7 +299,7 @@ logcontext is not finished before the `async` processing completes.
 **Bad**:
 ```python
-cache: Optional[ObservableDeferred[None]] = None
+cache: ObservableDeferred[None] | None = None
 async def do_something_else(
     to_resolve: Deferred[None]
@@ -326,7 +326,7 @@ with LoggingContext("request-1"):
 **Good**:
 ```python
-cache: Optional[ObservableDeferred[None]] = None
+cache: ObservableDeferred[None] | None = None
 async def do_something_else(
     to_resolve: Deferred[None]
@@ -358,7 +358,7 @@ with LoggingContext("request-1"):
 **OK**:
 ```python
-cache: Optional[ObservableDeferred[None]] = None
+cache: ObservableDeferred[None] | None = None
 async def do_something_else(
     to_resolve: Deferred[None]

@@ -15,7 +15,7 @@ _First introduced in Synapse v1.57.0_
 ```python
 async def on_account_data_updated(
     user_id: str,
-    room_id: Optional[str],
+    room_id: str | None,
     account_data_type: str,
     content: "synapse.module_api.JsonDict",
 ) -> None:
@@ -82,7 +82,7 @@ class CustomAccountDataModule:
     async def log_new_account_data(
         self,
         user_id: str,
-        room_id: Optional[str],
+        room_id: str | None,
         account_data_type: str,
         content: JsonDict,
     ) -> None:

@@ -12,7 +12,7 @@ The available account validity callbacks are:
 _First introduced in Synapse v1.39.0_
 ```python
-async def is_user_expired(user: str) -> Optional[bool]
+async def is_user_expired(user: str) -> bool | None
 ```
 Called when processing any authenticated request (except for logout requests). The module

@@ -11,7 +11,7 @@ The available media repository callbacks are:
 _First introduced in Synapse v1.132.0_
 ```python
-async def get_media_config_for_user(user_id: str) -> Optional[JsonDict]
+async def get_media_config_for_user(user_id: str) -> JsonDict | None
 ```
 **<span style="color:red">
@@ -70,7 +70,7 @@ implementations of this callback.
 _First introduced in Synapse v1.139.0_
 ```python
-async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[synapse.module_api.MediaUploadLimit]]
+async def get_media_upload_limits_for_user(user_id: str, size: int) -> list[synapse.module_api.MediaUploadLimit] | None
 ```
 **<span style="color:red">

@@ -23,12 +23,7 @@ async def check_auth(
     user: str,
     login_type: str,
     login_dict: "synapse.module_api.JsonDict",
-) -> Optional[
-    Tuple[
-        str,
-        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
-    ]
-]
+) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None
 ```
 The login type and field names should be provided by the user in the
@@ -67,12 +62,7 @@ async def check_3pid_auth(
     medium: str,
     address: str,
     password: str,
-) -> Optional[
-    Tuple[
-        str,
-        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
-    ]
-]
+) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None]
 ```
 Called when a user attempts to register or log in with a third party identifier,
@@ -98,7 +88,7 @@ _First introduced in Synapse v1.46.0_
 ```python
 async def on_logged_out(
     user_id: str,
-    device_id: Optional[str],
+    device_id: str | None,
     access_token: str
 ) -> None
 ```
@@ -119,7 +109,7 @@ _First introduced in Synapse v1.52.0_
 async def get_username_for_registration(
     uia_results: Dict[str, Any],
     params: Dict[str, Any],
-) -> Optional[str]
+) -> str | None
 ```
 Called when registering a new user. The module can return a username to set for the user
@@ -180,7 +170,7 @@ _First introduced in Synapse v1.54.0_
 async def get_displayname_for_registration(
     uia_results: Dict[str, Any],
     params: Dict[str, Any],
-) -> Optional[str]
+) -> str | None
 ```
 Called when registering a new user. The module can return a display name to set for the
@@ -259,12 +249,7 @@ class MyAuthProvider:
         username: str,
         login_type: str,
         login_dict: "synapse.module_api.JsonDict",
-    ) -> Optional[
-        Tuple[
-            str,
-            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
-        ]
-    ]:
+    ) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None:
         if login_type != "my.login_type":
             return None
@@ -276,12 +261,7 @@ class MyAuthProvider:
         username: str,
         login_type: str,
         login_dict: "synapse.module_api.JsonDict",
-    ) -> Optional[
-        Tuple[
-            str,
-            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
-        ]
-    ]:
+    ) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None:
         if login_type != "m.login.password":
             return None

@@ -23,7 +23,7 @@ _First introduced in Synapse v1.42.0_
 ```python
 async def get_users_for_states(
     state_updates: Iterable["synapse.api.UserPresenceState"],
-) -> Dict[str, Set["synapse.api.UserPresenceState"]]
+) -> dict[str, set["synapse.api.UserPresenceState"]]
 ```
 **Requires** `get_interested_users` to also be registered
@@ -45,7 +45,7 @@ _First introduced in Synapse v1.42.0_
 ```python
 async def get_interested_users(
     user_id: str
-) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]
+) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS"
 ```
 **Requires** `get_users_for_states` to also be registered
@@ -73,7 +73,7 @@ that `@alice:example.org` receives all presence updates from `@bob:example.com`
 `@charlie:somewhere.org`, regardless of whether Alice shares a room with any of them.
 ```python
-from typing import Dict, Iterable, Set, Union
+from typing import Iterable
 from synapse.module_api import ModuleApi
@@ -90,7 +90,7 @@ class CustomPresenceRouter:
     async def get_users_for_states(
         self,
         state_updates: Iterable["synapse.api.UserPresenceState"],
-    ) -> Dict[str, Set["synapse.api.UserPresenceState"]]:
+    ) -> dict[str, set["synapse.api.UserPresenceState"]]:
         res = {}
         for update in state_updates:
             if (
@@ -104,7 +104,7 @@ class CustomPresenceRouter:
     async def get_interested_users(
         self,
         user_id: str,
-    ) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]:
+    ) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS":
         if user_id == "@alice:example.com":
             return {"@bob:example.com", "@charlie:somewhere.org"}

@@ -11,7 +11,7 @@ The available ratelimit callbacks are:
 _First introduced in Synapse v1.132.0_
 ```python
-async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> Optional[synapse.module_api.RatelimitOverride]
+async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> synapse.module_api.RatelimitOverride | None
 ```
 **<span style="color:red">

@@ -331,9 +331,9 @@ search results; otherwise return `False`.
 The profile is represented as a dictionary with the following keys:
 * `user_id: str`. The Matrix ID for this user.
-* `display_name: Optional[str]`. The user's display name, or `None` if this user
+* `display_name: str | None`. The user's display name, or `None` if this user
   has not set a display name.
-* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
+* `avatar_url: str | None`. The `mxc://` URL to the user's avatar, or `None`
   if this user has not set an avatar.
 The module is given a copy of the original dictionary, so modifying it from within the
@@ -352,10 +352,10 @@ _First introduced in Synapse v1.37.0_
 ```python
 async def check_registration_for_spam(
-    email_threepid: Optional[dict],
-    username: Optional[str],
+    email_threepid: dict | None,
+    username: str | None,
     request_info: Collection[Tuple[str, str]],
-    auth_provider_id: Optional[str] = None,
+    auth_provider_id: str | None = None,
 ) -> "synapse.spam_checker_api.RegistrationBehaviour"
 ```
@@ -438,10 +438,10 @@ _First introduced in Synapse v1.87.0_
 ```python
 async def check_login_for_spam(
     user_id: str,
-    device_id: Optional[str],
-    initial_display_name: Optional[str],
-    request_info: Collection[Tuple[Optional[str], str]],
-    auth_provider_id: Optional[str] = None,
+    device_id: str | None,
+    initial_display_name: str | None,
+    request_info: Collection[tuple[str | None, str]],
+    auth_provider_id: str | None = None,
 ) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
 ```
@@ -509,7 +509,7 @@ class ListSpamChecker:
             resource=IsUserEvilResource(config),
         )
-    async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Union[Literal["NOT_SPAM"], Codes]:
+    async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Literal["NOT_SPAM"] | Codes:
         if event.sender in self.evil_users:
             return Codes.FORBIDDEN
         else:

@@ -16,7 +16,7 @@ _First introduced in Synapse v1.39.0_
 async def check_event_allowed(
     event: "synapse.events.EventBase",
     state_events: "synapse.types.StateMap",
-) -> Tuple[bool, Optional[dict]]
+) -> tuple[bool, dict | None]
 ```
 **<span style="color:red">
@@ -340,7 +340,7 @@ class EventCensorer:
         self,
         event: "synapse.events.EventBase",
         state_events: "synapse.types.StateMap",
-    ) -> Tuple[bool, Optional[dict]]:
+    ) -> Tuple[bool, dict | None]:
         event_dict = event.get_dict()
         new_event_content = await self.api.http_client.post_json_get_json(
             uri=self._endpoint, post_json=event_dict,

@@ -76,7 +76,7 @@ possible.
 #### `get_interested_users`
 ```python
-async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
+async def get_interested_users(self, user_id: str) -> set[str] | str
 ```
 **Required.** An asynchronous method that is passed a single Matrix User ID. This
@@ -182,7 +182,7 @@ class ExamplePresenceRouter:
     async def get_interested_users(
         self,
         user_id: str,
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    ) -> set[str] | PresenceRouter.ALL_USERS:
         """
         Retrieve a list of users that `user_id` is interested in receiving the
         presence of. This will be in addition to those they share a room with.

@@ -80,10 +80,15 @@ select = [
     "G",
     # pyupgrade
     "UP006",
+    "UP007",
+    "UP045",
 ]
 extend-safe-fixes = [
-    # pyupgrade
-    "UP006"
+    # pyupgrade rules compatible with Python >= 3.9
+    "UP006",
+    "UP007",
+    # pyupgrade rules compatible with Python >= 3.10
+    "UP045",
 ]
 [tool.ruff.lint.isort]
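
A note on the ruff configuration above: UP006, UP007 and UP045 are pyupgrade rules. As I understand these rule codes, UP006 rewrites `typing.Dict`/`typing.List`-style generics to the PEP 585 builtins, while UP007 and UP045 rewrite `Union[X, Y]` and `Optional[X]` to the PEP 604 `X | Y` / `X | None` forms. A hypothetical module (not part of this diff) illustrating the kind of annotations the newly enabled rules flag and can auto-fix:

```python
from typing import Dict, Optional, Union


# Flagged once UP006 (Dict), UP007 (Union) and UP045 (Optional) are enabled:
def parse_header(raw: Union[str, bytes]) -> Optional[Dict[str, str]]:
    ...


# Roughly what `ruff check --fix` would leave behind:
def parse_header_fixed(raw: str | bytes) -> dict[str, str] | None:
    ...
```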


@@ -18,7 +18,7 @@ import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from types import FrameType
-from typing import Collection, Optional, Sequence
+from typing import Collection, Sequence
 # These are expanded inside the dockerfile to be a fully qualified image name.
 # e.g. docker.io/library/debian:bookworm
@@ -49,7 +49,7 @@ class Builder:
     def __init__(
         self,
         redirect_stdout: bool = False,
-        docker_build_args: Optional[Sequence[str]] = None,
+        docker_build_args: Sequence[str] | None = None,
     ):
         self.redirect_stdout = redirect_stdout
         self._docker_build_args = tuple(docker_build_args or ())
@@ -167,7 +167,7 @@ class Builder:
 def run_builds(
     builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
 ) -> None:
-    def sig(signum: int, _frame: Optional[FrameType]) -> None:
+    def sig(signum: int, _frame: FrameType | None) -> None:
         print("Caught SIGINT")
         builder.kill_containers()

@@ -43,7 +43,7 @@ import argparse
 import base64
 import json
 import sys
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Mapping
 from urllib import parse as urlparse
 import requests
@@ -103,12 +103,12 @@ def sign_json(
 def request(
-    method: Optional[str],
+    method: str | None,
     origin_name: str,
     origin_key: signedjson.types.SigningKey,
     destination: str,
     path: str,
-    content: Optional[str],
+    content: str | None,
     verify_tls: bool,
 ) -> requests.Response:
     if method is None:
@@ -301,9 +301,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
     def get_connection_with_tls_context(
         self,
         request: PreparedRequest,
-        verify: Optional[Union[bool, str]],
-        proxies: Optional[Mapping[str, str]] = None,
-        cert: Optional[Union[tuple[str, str], str]] = None,
+        verify: bool | str | None,
+        proxies: Mapping[str, str] | None = None,
+        cert: tuple[str, str] | str | None = None,
     ) -> HTTPConnectionPool:
         # overrides the get_connection_with_tls_context() method in the base class
         parsed = urlparse.urlsplit(request.url)
@@ -368,7 +368,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
         return server_name, 8448, server_name
     @staticmethod
-    def _get_well_known(server_name: str) -> Optional[str]:
+    def _get_well_known(server_name: str) -> str | None:
         if ":" in server_name:
             # explicit port, or ipv6 literal. Either way, no .well-known
             return None

@@ -4,7 +4,7 @@
 import json
 import re
 import sys
-from typing import Any, Optional
+from typing import Any
 import yaml
@@ -259,17 +259,17 @@ def indent(text: str, first_line: bool = True) -> str:
     return text
-def em(s: Optional[str]) -> str:
+def em(s: str | None) -> str:
     """Add emphasis to text."""
     return f"*{s}*" if s else ""
-def a(s: Optional[str], suffix: str = " ") -> str:
+def a(s: str | None, suffix: str = " ") -> str:
     """Appends a space if the given string is not empty."""
     return s + suffix if s else ""
-def p(s: Optional[str], prefix: str = " ") -> str:
+def p(s: str | None, prefix: str = " ") -> str:
     """Prepend a space if the given string is not empty."""
     return prefix + s if s else ""

@@ -24,7 +24,7 @@ can crop up, e.g the cache descriptors.
 """
 import enum
-from typing import Callable, Mapping, Optional, Union
+from typing import Callable, Mapping
 import attr
 import mypy.types
@@ -123,7 +123,7 @@ class ArgLocation:
 """
-prometheus_metric_fullname_to_label_arg_map: Mapping[str, Optional[ArgLocation]] = {
+prometheus_metric_fullname_to_label_arg_map: Mapping[str, ArgLocation | None] = {
     # `Collector` subclasses:
     "prometheus_client.metrics.MetricWrapperBase": ArgLocation("labelnames", 2),
     "prometheus_client.metrics.Counter": ArgLocation("labelnames", 2),
@@ -211,7 +211,7 @@ class SynapsePlugin(Plugin):
     def get_base_class_hook(
         self, fullname: str
-    ) -> Optional[Callable[[ClassDefContext], None]]:
+    ) -> Callable[[ClassDefContext], None] | None:
         def _get_base_class_hook(ctx: ClassDefContext) -> None:
             # Run any `get_base_class_hook` checks from other plugins first.
             #
@@ -232,7 +232,7 @@ class SynapsePlugin(Plugin):
     def get_function_signature_hook(
         self, fullname: str
-    ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
+    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
         # Strip off the unique identifier for classes that are dynamically created inside
         # functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line
         # number)
@@ -262,7 +262,7 @@ class SynapsePlugin(Plugin):
     def get_method_signature_hook(
         self, fullname: str
-    ) -> Optional[Callable[[MethodSigContext], CallableType]]:
+    ) -> Callable[[MethodSigContext], CallableType] | None:
         if fullname.startswith(
             (
                 "synapse.util.caches.descriptors.CachedFunction.__call__",
@@ -721,7 +721,7 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType:
 def check_is_cacheable(
     signature: CallableType,
-    ctx: Union[MethodSigContext, FunctionSigContext],
+    ctx: MethodSigContext | FunctionSigContext,
 ) -> None:
     """
     Check if a callable returns a type which can be cached.
@@ -795,7 +795,7 @@ AT_CACHED_MUTABLE_RETURN = ErrorCode(
 def is_cacheable(
     rt: mypy.types.Type, signature: CallableType, verbose: bool
-) -> tuple[bool, Optional[str]]:
+) -> tuple[bool, str | None]:
     """
     Check if a particular type is cachable.

@@ -32,7 +32,7 @@ import time
 import urllib.request
 from os import path
 from tempfile import TemporaryDirectory
-from typing import Any, Match, Optional, Union
+from typing import Any, Match
 import attr
 import click
@@ -327,11 +327,11 @@ def _prepare() -> None:
 @cli.command()
 @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
-def tag(gh_token: Optional[str]) -> None:
+def tag(gh_token: str | None) -> None:
     _tag(gh_token)
-def _tag(gh_token: Optional[str]) -> None:
+def _tag(gh_token: str | None) -> None:
     """Tags the release and generates a draft GitHub release"""
     # Test that the GH Token is valid before continuing.
@@ -471,11 +471,11 @@ def _publish(gh_token: str) -> None:
 @cli.command()
 @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
-def upload(gh_token: Optional[str]) -> None:
+def upload(gh_token: str | None) -> None:
     _upload(gh_token)
-def _upload(gh_token: Optional[str]) -> None:
+def _upload(gh_token: str | None) -> None:
     """Upload release to pypi."""
     # Test that the GH Token is valid before continuing.
@@ -576,11 +576,11 @@ def _merge_into(repo: Repo, source: str, target: str) -> None:
 @cli.command()
 @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
-def wait_for_actions(gh_token: Optional[str]) -> None:
+def wait_for_actions(gh_token: str | None) -> None:
     _wait_for_actions(gh_token)
-def _wait_for_actions(gh_token: Optional[str]) -> None:
+def _wait_for_actions(gh_token: str | None) -> None:
     # Test that the GH Token is valid before continuing.
     check_valid_gh_token(gh_token)
@@ -658,7 +658,7 @@ def _notify(message: str) -> None:
     envvar=["GH_TOKEN", "GITHUB_TOKEN"],
     required=False,
 )
-def merge_back(_gh_token: Optional[str]) -> None:
+def merge_back(_gh_token: str | None) -> None:
     _merge_back()
@@ -715,7 +715,7 @@ def _merge_back() -> None:
     envvar=["GH_TOKEN", "GITHUB_TOKEN"],
     required=False,
 )
-def announce(_gh_token: Optional[str]) -> None:
+def announce(_gh_token: str | None) -> None:
     _announce()
@@ -851,7 +851,7 @@ def get_repo_and_check_clean_checkout(
     return repo
-def check_valid_gh_token(gh_token: Optional[str]) -> None:
+def check_valid_gh_token(gh_token: str | None) -> None:
     """Check that a github token is valid, if supplied"""
     if not gh_token:
@@ -867,7 +867,7 @@ def check_valid_gh_token(gh_token: Optional[str]) -> None:
         raise click.ClickException(f"Github credentials are bad: {e}")
-def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
+def find_ref(repo: git.Repo, ref_name: str) -> git.HEAD | None:
     """Find the branch/ref, looking first locally then in the remote."""
     if ref_name in repo.references:
         return repo.references[ref_name]
@@ -904,7 +904,7 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
         # These are 0-based.
         start_line: int
-        end_line: Optional[int] = None  # Is none if its the last entry
+        end_line: int | None = None  # Is none if its the last entry
     headings: list[VersionSection] = []
     for i, token in enumerate(tokens):
@@ -991,7 +991,7 @@ def build_dependabot_changelog(repo: Repo, current_version: version.Version) ->
     messages = []
     for commit in reversed(commits):
         if commit.author.name == "dependabot[bot]":
-            message: Union[str, bytes] = commit.message
+            message: str | bytes = commit.message
             if isinstance(message, bytes):
                 message = message.decode("utf-8")
             messages.append(message.split("\n", maxsplit=1)[0])

@@ -38,7 +38,7 @@ import io
 import json
 import sys
 from collections import defaultdict
-from typing import Any, Iterator, Optional
+from typing import Any, Iterator
 import git
 from packaging import version
@@ -57,7 +57,7 @@ SCHEMA_VERSION_FILES = (
 OLDEST_SHOWN_VERSION = version.parse("v1.0")
-def get_schema_versions(tag: git.Tag) -> tuple[Optional[int], Optional[int]]:
+def get_schema_versions(tag: git.Tag) -> tuple[int | None, int | None]:
     """Get the schema and schema compat versions for a tag."""
     schema_version = None
     schema_compat_version = None

@@ -13,10 +13,8 @@ from typing import (
     Iterator,
     KeysView,
     Mapping,
-    Optional,
     Sequence,
     TypeVar,
-    Union,
     ValuesView,
     overload,
 )
@@ -51,7 +49,7 @@ class SortedDict(dict[_KT, _VT]):
         self, __key: _Key[_KT], __iterable: Iterable[tuple[_KT, _VT]], **kwargs: _VT
     ) -> None: ...
     @property
-    def key(self) -> Optional[_Key[_KT]]: ...
+    def key(self) -> _Key[_KT] | None: ...
     @property
     def iloc(self) -> SortedKeysView[_KT]: ...
     def clear(self) -> None: ...
@@ -79,10 +77,10 @@ class SortedDict(dict[_KT, _VT]):
     @overload
     def pop(self, key: _KT) -> _VT: ...
     @overload
-    def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ...
+    def pop(self, key: _KT, default: _T = ...) -> _VT | _T: ...
     def popitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
     def peekitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
-    def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ...
+    def setdefault(self, key: _KT, default: _VT | None = ...) -> _VT: ...
     # Mypy now reports the first overload as an error, because typeshed widened the type
     # of `__map` to its internal `_typeshed.SupportsKeysAndGetItem` type in
     # https://github.com/python/typeshed/pull/6653
@@ -106,8 +104,8 @@ class SortedDict(dict[_KT, _VT]):
     def _check(self) -> None: ...
     def islice(
         self,
-        start: Optional[int] = ...,
-        stop: Optional[int] = ...,
+        start: int | None = ...,
+        stop: int | None = ...,
         reverse: bool = ...,
     ) -> Iterator[_KT]: ...
     def bisect_left(self, value: _KT) -> int: ...
@@ -118,7 +116,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
     def __getitem__(self, index: int) -> _KT_co: ...
     @overload
     def __getitem__(self, index: slice) -> list[_KT_co]: ...
-    def __delitem__(self, index: Union[int, slice]) -> None: ...
+    def __delitem__(self, index: int | slice) -> None: ...
 class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]):
     def __iter__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ...
@@ -126,11 +124,11 @@ class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]
     def __getitem__(self, index: int) -> tuple[_KT_co, _VT_co]: ...
     @overload
     def __getitem__(self, index: slice) -> list[tuple[_KT_co, _VT_co]]: ...
-    def __delitem__(self, index: Union[int, slice]) -> None: ...
+    def __delitem__(self, index: int | slice) -> None: ...
 class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]):
     @overload
     def __getitem__(self, index: int) -> _VT_co: ...
     @overload
     def __getitem__(self, index: slice) -> list[_VT_co]: ...
-    def __delitem__(self, index: Union[int, slice]) -> None: ...
+    def __delitem__(self, index: int | slice) -> None: ...

@@ -10,10 +10,8 @@ from typing import (
     Iterable,
     Iterator,
     MutableSequence,
-    Optional,
     Sequence,
     TypeVar,
-    Union,
     overload,
 )
@@ -29,8 +27,8 @@ class SortedList(MutableSequence[_T]):
     DEFAULT_LOAD_FACTOR: int = ...
     def __init__(
         self,
-        iterable: Optional[Iterable[_T]] = ...,
-        key: Optional[_Key[_T]] = ...,
+        iterable: Iterable[_T] | None = ...,
+        key: _Key[_T] | None = ...,
     ): ...
     # NB: currently mypy does not honour return type, see mypy #3307
     @overload
@@ -42,7 +40,7 @@ class SortedList(MutableSequence[_T]):
     @overload
     def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
     @property
-    def key(self) -> Optional[Callable[[_T], Any]]: ...
+    def key(self) -> Callable[[_T], Any] | None: ...
     def _reset(self, load: int) -> None: ...
     def clear(self) -> None: ...
     def _clear(self) -> None: ...
@@ -57,7 +55,7 @@ class SortedList(MutableSequence[_T]):
     def _pos(self, idx: int) -> int: ...
     def _build_index(self) -> None: ...
     def __contains__(self, value: Any) -> bool: ...
-    def __delitem__(self, index: Union[int, slice]) -> None: ...
+    def __delitem__(self, index: int | slice) -> None: ...
     @overload
     def __getitem__(self, index: int) -> _T: ...
     @overload
@@ -76,8 +74,8 @@ class SortedList(MutableSequence[_T]):
     def reverse(self) -> None: ...
     def islice(
         self,
-        start: Optional[int] = ...,
-        stop: Optional[int] = ...,
+        start: int | None = ...,
+        stop: int | None = ...,
         reverse: bool = ...,
     ) -> Iterator[_T]: ...
     def _islice(
@@ -90,8 +88,8 @@ class SortedList(MutableSequence[_T]):
     ) -> Iterator[_T]: ...
     def irange(
         self,
-        minimum: Optional[int] = ...,
-        maximum: Optional[int] = ...,
+        minimum: int | None = ...,
+        maximum: int | None = ...,
         inclusive: tuple[bool, bool] = ...,
         reverse: bool = ...,
     ) -> Iterator[_T]: ...
@@ -107,7 +105,7 @@ class SortedList(MutableSequence[_T]):
     def insert(self, index: int, value: _T) -> None: ...
     def pop(self, index: int = ...) -> _T: ...
     def index(
-        self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+        self, value: _T, start: int | None = ..., stop: int | None = ...
     ) -> int: ...
     def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
     def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
@@ -126,10 +124,10 @@ class SortedList(MutableSequence[_T]):
 class SortedKeyList(SortedList[_T]):
     def __init__(
-        self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
+        self, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ...
     ) -> None: ...
     def __new__(
-        cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
+        cls, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ...
     ) -> SortedKeyList[_T]: ...
     @property
     def key(self) -> Callable[[_T], Any]: ...
@@ -146,15 +144,15 @@ class SortedKeyList(SortedList[_T]):
     def _delete(self, pos: int, idx: int) -> None: ...
     def irange(
         self,
-        minimum: Optional[int] = ...,
-        maximum: Optional[int] = ...,
+        minimum: int | None = ...,
+        maximum: int | None = ...,
         inclusive: tuple[bool, bool] = ...,
         reverse: bool = ...,
     ) -> Iterator[_T]: ...
     def irange_key(
         self,
-        min_key: Optional[Any] = ...,
-        max_key: Optional[Any] = ...,
+        min_key: Any | None = ...,
+        max_key: Any | None = ...,
         inclusive: tuple[bool, bool] = ...,
         reserve: bool = ...,
     ) -> Iterator[_T]: ...
@@ -170,7 +168,7 @@ class SortedKeyList(SortedList[_T]):
     def copy(self: _SKL) -> _SKL: ...
     def __copy__(self: _SKL) -> _SKL: ...
     def index(
-        self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+        self, value: _T, start: int | None = ..., stop: int | None = ...
     ) -> int: ...
     def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
     def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...

@@ -11,10 +11,8 @@ from typing import (
     Iterable,
     Iterator,
     MutableSet,
-    Optional,
     Sequence,
     TypeVar,
-    Union,
     overload,
 )
@@ -28,21 +26,19 @@ _Key = Callable[[_T], Any]
 class SortedSet(MutableSet[_T], Sequence[_T]):
     def __init__(
         self,
-        iterable: Optional[Iterable[_T]] = ...,
-        key: Optional[_Key[_T]] = ...,
+        iterable: Iterable[_T] | None = ...,
+        key: _Key[_T] | None = ...,
     ) -> None: ...
     @classmethod
-    def _fromset(
-        cls, values: set[_T], key: Optional[_Key[_T]] = ...
-    ) -> SortedSet[_T]: ...
+    def _fromset(cls, values: set[_T], key: _Key[_T] | None = ...) -> SortedSet[_T]: ...
     @property
-    def key(self) -> Optional[_Key[_T]]: ...
+    def key(self) -> _Key[_T] | None: ...
     def __contains__(self, value: Any) -> bool: ...
     @overload
     def __getitem__(self, index: int) -> _T: ...
     @overload
     def __getitem__(self, index: slice) -> list[_T]: ...
-    def __delitem__(self, index: Union[int, slice]) -> None: ...
+    def __delitem__(self, index: int | slice) -> None: ...
     def __eq__(self, other: Any) -> bool: ...
     def __ne__(self, other: Any) -> bool: ...
     def __lt__(self, other: Iterable[_T]) -> bool: ...
@@ -62,32 +58,28 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
     def _discard(self, value: _T) -> None: ...
     def pop(self, index: int = ...) -> _T: ...
     def remove(self, value: _T) -> None: ...
-    def difference(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def difference_update(
-        self, *iterables: Iterable[_S]
-    ) -> SortedSet[Union[_T, _S]]: ...
-    def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def intersection(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __and__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def intersection_update(
-        self, *iterables: Iterable[_S]
-    ) -> SortedSet[Union[_T, _S]]: ...
-    def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __xor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __rxor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+    def difference(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def difference_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def intersection(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __and__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def intersection_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __xor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __rxor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
     def symmetric_difference_update(
         self, other: Iterable[_S]
-    ) -> SortedSet[Union[_T, _S]]: ...
-    def __ixor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def union(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __or__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
-    def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+    ) -> SortedSet[_T | _S]: ...
+    def __ixor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def union(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __or__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
+    def _update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ...
     def __reduce__(
         self,
     ) -> tuple[type[SortedSet[_T]], set[_T], Callable[[_T], Any]]: ...
@@ -97,18 +89,18 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
     def bisect_right(self, value: _T) -> int: ...
     def islice(
         self,
-        start: Optional[int] = ...,
-        stop: Optional[int] = ...,
+        start: int | None = ...,
+        stop: int | None = ...,
         reverse: bool = ...,
     ) -> Iterator[_T]: ...
     def irange(
         self,
-        minimum: Optional[_T] = ...,
-        maximum: Optional[_T] = ...,
+        minimum: _T | None = ...,
+        maximum: _T | None = ...,
         inclusive: tuple[bool, bool] = ...,
         reverse: bool = ...,
     ) -> Iterator[_T]: ...
     def index(
-        self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+        self, value: _T, start: int | None = ..., stop: int | None = ...
     ) -> int: ...
     def _reset(self, load: int) -> None: ...

@@ -15,7 +15,7 @@
 """Contains *incomplete* type hints for txredisapi."""
-from typing import Any, Optional, Union
+from typing import Any
 from twisted.internet import protocol
 from twisted.internet.defer import Deferred
@@ -29,8 +29,8 @@ class RedisProtocol(protocol.Protocol):
         self,
         key: str,
         value: Any,
-        expire: Optional[int] = None,
-        pexpire: Optional[int] = None,
+        expire: int | None = None,
+        pexpire: int | None = None,
         only_if_not_exists: bool = False,
         only_if_exists: bool = False,
     ) -> "Deferred[None]": ...
@@ -38,8 +38,8 @@ class RedisProtocol(protocol.Protocol):
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args: object, **kwargs: object): ...
-    password: Optional[str]
-    def subscribe(self, channels: Union[str, list[str]]) -> "Deferred[None]": ...
+    password: str | None
+    def subscribe(self, channels: str | list[str]) -> "Deferred[None]": ...
     def connectionMade(self) -> None: ...
     # type-ignore: twisted.internet.protocol.Protocol provides a default argument for
     # `reason`. txredisapi's LineReceiver Protocol doesn't. But that's fine: it's what's
@@ -49,12 +49,12 @@ class SubscriberProtocol(RedisProtocol):
 def lazyConnection(
     host: str = ...,
     port: int = ...,
-    dbid: Optional[int] = ...,
+    dbid: int | None = ...,
     reconnect: bool = ...,
     charset: str = ...,
-    password: Optional[str] = ...,
-    connectTimeout: Optional[int] = ...,
-    replyTimeout: Optional[int] = ...,
+    password: str | None = ...,
+    connectTimeout: int | None = ...,
+    replyTimeout: int | None = ...,
     convertNumbers: bool = ...,
 ) -> RedisProtocol: ...
@@ -70,18 +70,18 @@ class RedisFactory(protocol.ReconnectingClientFactory):
     continueTrying: bool
     handler: ConnectionHandler
     pool: list[RedisProtocol]
-    replyTimeout: Optional[int]
+    replyTimeout: int | None
     def __init__(
         self,
         uuid: str,
-        dbid: Optional[int],
+        dbid: int | None,
         poolsize: int,
         isLazy: bool = False,
         handler: type = ConnectionHandler,
         charset: str = "utf-8",
-        password: Optional[str] = None,
-        replyTimeout: Optional[int] = None,
-        convertNumbers: Optional[int] = True,
+        password: str | None = None,
+        replyTimeout: int | None = None,
+        convertNumbers: int | None = True,
     ): ...
     def buildProtocol(self, addr: IAddress) -> RedisProtocol: ...

@@ -22,13 +22,13 @@
 import argparse
 import sys
 import time
-from typing import NoReturn, Optional
+from typing import NoReturn
 from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
 from signedjson.types import VerifyKey
-def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
+def exit(status: int = 0, message: str | None = None) -> NoReturn:
     if message:
         print(message, file=sys.stderr)
     sys.exit(status)

@@ -25,7 +25,7 @@ import logging
 import re
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Iterable, Optional, Pattern
+from typing import Iterable, Pattern
 import yaml
@@ -46,7 +46,7 @@ logger = logging.getLogger("generate_workers_map")
 class MockHomeserver(HomeServer):
     DATASTORE_CLASS = DataStore
-    def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
+    def __init__(self, config: HomeServerConfig, worker_app: str | None) -> None:
         super().__init__(config.server.server_name, config=config)
         self.config.worker.worker_app = worker_app
@@ -65,7 +65,7 @@ class EndpointDescription:
     # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
     # class.
-    category: Optional[str]
+    category: str | None
     # TODO:
     #  - does it need to be routed based on a stream writer config?
@@ -141,7 +141,7 @@ def get_registered_paths_for_hs(
 def get_registered_paths_for_default(
-    worker_app: Optional[str], base_config: HomeServerConfig
+    worker_app: str | None, base_config: HomeServerConfig
 ) -> dict[tuple[str, str], EndpointDescription]:
     """
     Given the name of a worker application and a base homeserver configuration,
@@ -271,7 +271,7 @@ def main() -> None:
     # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
     categories_to_methods_and_paths: dict[
-        Optional[str], dict[tuple[str, str], EndpointDescription]
+        str | None, dict[tuple[str, str], EndpointDescription]
     ] = defaultdict(dict)
     for (method, path), desc in elided_worker_paths.items():
@@ -282,7 +282,7 @@ def main() -> None:
 def print_category(
-    category_name: Optional[str],
+    category_name: str | None,
     elided_worker_paths: dict[tuple[str, str], EndpointDescription],
 ) -> None:
     """

@@ -26,7 +26,7 @@ import hashlib
 import hmac
 import logging
 import sys
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 import requests
 import yaml
@@ -54,7 +54,7 @@ def request_registration(
     server_location: str,
     shared_secret: str,
     admin: bool = False,
-    user_type: Optional[str] = None,
+    user_type: str | None = None,
     _print: Callable[[str], None] = print,
     exit: Callable[[int], None] = sys.exit,
     exists_ok: bool = False,
@@ -123,13 +123,13 @@ def register_new_user(
     password: str,
     server_location: str,
     shared_secret: str,
-    admin: Optional[bool],
-    user_type: Optional[str],
+    admin: bool | None,
+    user_type: str | None,
     exists_ok: bool = False,
 ) -> None:
     if not user:
         try:
-            default_user: Optional[str] = getpass.getuser()
+            default_user: str | None = getpass.getuser()
         except Exception:
             default_user = None
@@ -262,7 +262,7 @@ def main() -> None:
     args = parser.parse_args()
-    config: Optional[dict[str, Any]] = None
+    config: dict[str, Any] | None = None
     if "config" in args and args.config:
         config = yaml.safe_load(args.config)
@@ -350,7 +350,7 @@ def _read_file(file_path: Any, config_path: str) -> str:
         sys.exit(1)
-def _find_client_listener(config: dict[str, Any]) -> Optional[str]:
+def _find_client_listener(config: dict[str, Any]) -> str | None:
     # try to find a listener in the config. Returns a host:port pair
     for listener in config.get("listeners", []):
         if listener.get("type") != "http" or listener.get("tls", False):
View File
@ -233,14 +233,14 @@ IGNORED_BACKGROUND_UPDATES = {
# Error returned by the run function. Used at the top-level part of the script to # Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes. # handle errors and return codes.
end_error: Optional[str] = None end_error: str | None = None
# The exec_info for the error, if any. If error is defined but not exec_info the script # The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but # will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run # not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace. # function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info: Optional[ end_error_exec_info: tuple[type[BaseException], BaseException, TracebackType] | None = (
tuple[type[BaseException], BaseException, TracebackType] None
] = None )
R = TypeVar("R") R = TypeVar("R")
@ -485,7 +485,7 @@ class Porter:
def r( def r(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> tuple[Optional[list[str]], list[tuple], list[tuple]]: ) -> tuple[list[str] | None, list[tuple], list[tuple]]:
forward_rows = [] forward_rows = []
backward_rows = [] backward_rows = []
if do_forward[0]: if do_forward[0]:
@ -502,7 +502,7 @@ class Porter:
if forward_rows or backward_rows: if forward_rows or backward_rows:
assert txn.description is not None assert txn.description is not None
headers: Optional[list[str]] = [ headers: list[str] | None = [
column[0] for column in txn.description column[0] for column in txn.description
] ]
else: else:
@ -1152,9 +1152,7 @@ class Porter:
return done, remaining + done return done, remaining + done
async def _setup_state_group_id_seq(self) -> None: async def _setup_state_group_id_seq(self) -> None:
curr_id: Optional[ curr_id: int | None = await self.sqlite_store.db_pool.simple_select_one_onecol(
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
) )
@ -1271,10 +1269,10 @@ class Porter:
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (seq_name,), r) await self.postgres_store.db_pool.runInteraction("_setup_%s" % (seq_name,), r)
async def _pg_get_serial_sequence(self, table: str, column: str) -> Optional[str]: async def _pg_get_serial_sequence(self, table: str, column: str) -> str | None:
"""Returns the name of the postgres sequence associated with a column, or NULL.""" """Returns the name of the postgres sequence associated with a column, or NULL."""
def r(txn: LoggingTransaction) -> Optional[str]: def r(txn: LoggingTransaction) -> str | None:
txn.execute("SELECT pg_get_serial_sequence('%s', '%s')" % (table, column)) txn.execute("SELECT pg_get_serial_sequence('%s', '%s')" % (table, column))
result = txn.fetchone() result = txn.fetchone()
if not result: if not result:
@ -1286,9 +1284,9 @@ class Porter:
) )
async def _setup_auth_chain_sequence(self) -> None: async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id: Optional[ curr_chain_id: (
int int | None
] = await self.sqlite_store.db_pool.simple_select_one_onecol( ) = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains", table="event_auth_chains",
keyvalues={}, keyvalues={},
retcol="MAX(chain_id)", retcol="MAX(chain_id)",
View File
@ -30,7 +30,7 @@ import signal
import subprocess import subprocess
import sys import sys
import time import time
from typing import Iterable, NoReturn, Optional, TextIO from typing import Iterable, NoReturn, TextIO
import yaml import yaml
@ -135,7 +135,7 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool)
return False return False
def stop(pidfile: str, app: str) -> Optional[int]: def stop(pidfile: str, app: str) -> int | None:
"""Attempts to kill a synapse worker from the pidfile. """Attempts to kill a synapse worker from the pidfile.
Args: Args:
pidfile: path to file containing worker's pid pidfile: path to file containing worker's pid
View File
@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import TYPE_CHECKING, Optional, Protocol from typing import TYPE_CHECKING, Protocol
from prometheus_client import Histogram from prometheus_client import Histogram
@ -51,7 +51,7 @@ class Auth(Protocol):
room_id: str, room_id: str,
requester: Requester, requester: Requester,
allow_departed_users: bool = False, allow_departed_users: bool = False,
) -> tuple[str, Optional[str]]: ) -> tuple[str, str | None]:
"""Check if the user is in the room, or was at some point. """Check if the user is in the room, or was at some point.
Args: Args:
room_id: The room to check. room_id: The room to check.
@ -190,7 +190,7 @@ class Auth(Protocol):
async def check_user_in_room_or_world_readable( async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> tuple[str, Optional[str]]: ) -> tuple[str, str | None]:
"""Checks that the user is or was in the room or the room is world """Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised. readable. If it isn't then an exception is raised.
View File
@ -19,7 +19,7 @@
# #
# #
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from netaddr import IPAddress from netaddr import IPAddress
@ -64,7 +64,7 @@ class BaseAuth:
room_id: str, room_id: str,
requester: Requester, requester: Requester,
allow_departed_users: bool = False, allow_departed_users: bool = False,
) -> tuple[str, Optional[str]]: ) -> tuple[str, str | None]:
"""Check if the user is in the room, or was at some point. """Check if the user is in the room, or was at some point.
Args: Args:
room_id: The room to check. room_id: The room to check.
@ -114,7 +114,7 @@ class BaseAuth:
@trace @trace
async def check_user_in_room_or_world_readable( async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> tuple[str, Optional[str]]: ) -> tuple[str, str | None]:
"""Checks that the user is or was in the room or the room is world """Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised. readable. If it isn't then an exception is raised.
@ -294,7 +294,7 @@ class BaseAuth:
@cancellable @cancellable
async def get_appservice_user( async def get_appservice_user(
self, request: Request, access_token: str self, request: Request, access_token: str
) -> Optional[Requester]: ) -> Requester | None:
""" """
Given a request, reads the request parameters to determine: Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request - whether it's an application service that's making this request
View File
@ -13,7 +13,7 @@
# #
# #
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from urllib.parse import urlencode from urllib.parse import urlencode
from pydantic import ( from pydantic import (
@ -74,11 +74,11 @@ class ServerMetadata(BaseModel):
class IntrospectionResponse(BaseModel): class IntrospectionResponse(BaseModel):
retrieved_at_ms: StrictInt retrieved_at_ms: StrictInt
active: StrictBool active: StrictBool
scope: Optional[StrictStr] = None scope: StrictStr | None = None
username: Optional[StrictStr] = None username: StrictStr | None = None
sub: Optional[StrictStr] = None sub: StrictStr | None = None
device_id: Optional[StrictStr] = None device_id: StrictStr | None = None
expires_in: Optional[StrictInt] = None expires_in: StrictInt | None = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
def get_scope_set(self) -> set[str]: def get_scope_set(self) -> set[str]:
View File
@ -20,7 +20,7 @@
# #
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlencode from urllib.parse import urlencode
from authlib.oauth2 import ClientAuth from authlib.oauth2 import ClientAuth
@ -102,25 +102,25 @@ class IntrospectionResult:
return [] return []
return scope_to_list(value) return scope_to_list(value)
def get_sub(self) -> Optional[str]: def get_sub(self) -> str | None:
value = self._inner.get("sub") value = self._inner.get("sub")
if not isinstance(value, str): if not isinstance(value, str):
return None return None
return value return value
def get_username(self) -> Optional[str]: def get_username(self) -> str | None:
value = self._inner.get("username") value = self._inner.get("username")
if not isinstance(value, str): if not isinstance(value, str):
return None return None
return value return value
def get_name(self) -> Optional[str]: def get_name(self) -> str | None:
value = self._inner.get("name") value = self._inner.get("name")
if not isinstance(value, str): if not isinstance(value, str):
return None return None
return value return value
def get_device_id(self) -> Optional[str]: def get_device_id(self) -> str | None:
value = self._inner.get("device_id") value = self._inner.get("device_id")
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise AuthError( raise AuthError(
@ -174,7 +174,7 @@ class MSC3861DelegatedAuth(BaseAuth):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._hostname = hs.hostname self._hostname = hs.hostname
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token self._admin_token: Callable[[], str | None] = self._config.admin_token
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
self._rust_http_client = HttpClient( self._rust_http_client = HttpClient(
@ -247,7 +247,7 @@ class MSC3861DelegatedAuth(BaseAuth):
metadata = await self._issuer_metadata.get() metadata = await self._issuer_metadata.get()
return metadata.issuer or self._config.issuer return metadata.issuer or self._config.issuer
async def account_management_url(self) -> Optional[str]: async def account_management_url(self) -> str | None:
""" """
Get the configured account management URL Get the configured account management URL
View File
@ -20,7 +20,7 @@
# #
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
@ -51,10 +51,10 @@ class AuthBlocking:
async def check_auth_blocking( async def check_auth_blocking(
self, self,
user_id: Optional[str] = None, user_id: str | None = None,
threepid: Optional[dict] = None, threepid: dict | None = None,
user_type: Optional[str] = None, user_type: str | None = None,
requester: Optional[Requester] = None, requester: Requester | None = None,
) -> None: ) -> None:
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
View File
@ -26,7 +26,7 @@ import math
import typing import typing
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Optional, Union from typing import Any, Optional
from twisted.web import http from twisted.web import http
@ -164,9 +164,9 @@ class CodeMessageException(RuntimeError):
def __init__( def __init__(
self, self,
code: Union[int, HTTPStatus], code: int | HTTPStatus,
msg: str, msg: str,
headers: Optional[dict[str, str]] = None, headers: dict[str, str] | None = None,
): ):
super().__init__("%d: %s" % (code, msg)) super().__init__("%d: %s" % (code, msg))
@ -223,8 +223,8 @@ class SynapseError(CodeMessageException):
code: int, code: int,
msg: str, msg: str,
errcode: str = Codes.UNKNOWN, errcode: str = Codes.UNKNOWN,
additional_fields: Optional[dict] = None, additional_fields: dict | None = None,
headers: Optional[dict[str, str]] = None, headers: dict[str, str] | None = None,
): ):
"""Constructs a synapse error. """Constructs a synapse error.
@ -244,7 +244,7 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode, **self._additional_fields) return cs_error(self.msg, self.errcode, **self._additional_fields)
@property @property
def debug_context(self) -> Optional[str]: def debug_context(self) -> str | None:
"""Override this to add debugging context that shouldn't be sent to clients.""" """Override this to add debugging context that shouldn't be sent to clients."""
return None return None
@ -276,7 +276,7 @@ class ProxiedRequestError(SynapseError):
code: int, code: int,
msg: str, msg: str,
errcode: str = Codes.UNKNOWN, errcode: str = Codes.UNKNOWN,
additional_fields: Optional[dict] = None, additional_fields: dict | None = None,
): ):
super().__init__(code, msg, errcode, additional_fields) super().__init__(code, msg, errcode, additional_fields)
@ -340,7 +340,7 @@ class FederationDeniedError(SynapseError):
destination: The destination which has been denied destination: The destination which has been denied
""" """
def __init__(self, destination: Optional[str]): def __init__(self, destination: str | None):
"""Raised by federation client or server to indicate that we are """Raised by federation client or server to indicate that we are
are deliberately not attempting to contact a given server because it is are deliberately not attempting to contact a given server because it is
not on our federation whitelist. not on our federation whitelist.
@ -399,7 +399,7 @@ class AuthError(SynapseError):
code: int, code: int,
msg: str, msg: str,
errcode: str = Codes.FORBIDDEN, errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None, additional_fields: dict | None = None,
): ):
super().__init__(code, msg, errcode, additional_fields) super().__init__(code, msg, errcode, additional_fields)
@ -432,7 +432,7 @@ class UnstableSpecAuthError(AuthError):
msg: str, msg: str,
errcode: str, errcode: str,
previous_errcode: str = Codes.FORBIDDEN, previous_errcode: str = Codes.FORBIDDEN,
additional_fields: Optional[dict] = None, additional_fields: dict | None = None,
): ):
self.previous_errcode = previous_errcode self.previous_errcode = previous_errcode
super().__init__(code, msg, errcode, additional_fields) super().__init__(code, msg, errcode, additional_fields)
@ -497,8 +497,8 @@ class ResourceLimitError(SynapseError):
code: int, code: int,
msg: str, msg: str,
errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED, errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact: Optional[str] = None, admin_contact: str | None = None,
limit_type: Optional[str] = None, limit_type: str | None = None,
): ):
self.admin_contact = admin_contact self.admin_contact = admin_contact
self.limit_type = limit_type self.limit_type = limit_type
@ -542,7 +542,7 @@ class InvalidCaptchaError(SynapseError):
self, self,
code: int = 400, code: int = 400,
msg: str = "Invalid captcha.", msg: str = "Invalid captcha.",
error_url: Optional[str] = None, error_url: str | None = None,
errcode: str = Codes.CAPTCHA_INVALID, errcode: str = Codes.CAPTCHA_INVALID,
): ):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
@ -563,9 +563,9 @@ class LimitExceededError(SynapseError):
self, self,
limiter_name: str, limiter_name: str,
code: int = 429, code: int = 429,
retry_after_ms: Optional[int] = None, retry_after_ms: int | None = None,
errcode: str = Codes.LIMIT_EXCEEDED, errcode: str = Codes.LIMIT_EXCEEDED,
pause: Optional[float] = None, pause: float | None = None,
): ):
# Use HTTP header Retry-After to enable library-assisted retry handling. # Use HTTP header Retry-After to enable library-assisted retry handling.
headers = ( headers = (
@ -582,7 +582,7 @@ class LimitExceededError(SynapseError):
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@property @property
def debug_context(self) -> Optional[str]: def debug_context(self) -> str | None:
return self.limiter_name return self.limiter_name
@ -675,7 +675,7 @@ class RequestSendFailed(RuntimeError):
class UnredactedContentDeletedError(SynapseError): class UnredactedContentDeletedError(SynapseError):
def __init__(self, content_keep_ms: Optional[int] = None): def __init__(self, content_keep_ms: int | None = None):
super().__init__( super().__init__(
404, 404,
"The content for that event has already been erased from the database", "The content for that event has already been erased from the database",
@ -751,7 +751,7 @@ class FederationError(RuntimeError):
code: int, code: int,
reason: str, reason: str,
affected: str, affected: str,
source: Optional[str] = None, source: str | None = None,
): ):
if level not in ["FATAL", "ERROR", "WARN"]: if level not in ["FATAL", "ERROR", "WARN"]:
raise ValueError("Level is not valid: %s" % (level,)) raise ValueError("Level is not valid: %s" % (level,))
@ -786,7 +786,7 @@ class FederationPullAttemptBackoffError(RuntimeError):
""" """
def __init__( def __init__(
self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int self, event_ids: "StrCollection", message: str | None, retry_after_ms: int
): ):
event_ids = list(event_ids) event_ids = list(event_ids)
View File
@ -28,9 +28,7 @@ from typing import (
Collection, Collection,
Iterable, Iterable,
Mapping, Mapping,
Optional,
TypeVar, TypeVar,
Union,
) )
import jsonschema import jsonschema
@ -155,7 +153,7 @@ class Filtering:
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter( async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str] self, user_id: UserID, filter_id: int | str
) -> "FilterCollection": ) -> "FilterCollection":
result = await self.store.get_user_filter(user_id, filter_id) result = await self.store.get_user_filter(user_id, filter_id)
return FilterCollection(self._hs, result) return FilterCollection(self._hs, result)
@ -531,7 +529,7 @@ class Filter:
return newFilter return newFilter
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: def _matches_wildcard(actual_value: str | None, filter_value: str) -> bool:
if filter_value.endswith("*") and isinstance(actual_value, str): if filter_value.endswith("*") and isinstance(actual_value, str):
type_prefix = filter_value[:-1] type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix) return actual_value.startswith(type_prefix)
View File
@ -19,7 +19,7 @@
# #
# #
from typing import Any, Optional from typing import Any
import attr import attr
@ -41,15 +41,13 @@ class UserDevicePresenceState:
""" """
user_id: str user_id: str
device_id: Optional[str] device_id: str | None
state: str state: str
last_active_ts: int last_active_ts: int
last_sync_ts: int last_sync_ts: int
@classmethod @classmethod
def default( def default(cls, user_id: str, device_id: str | None) -> "UserDevicePresenceState":
cls, user_id: str, device_id: Optional[str]
) -> "UserDevicePresenceState":
"""Returns a default presence state.""" """Returns a default presence state."""
return cls( return cls(
user_id=user_id, user_id=user_id,
@ -81,7 +79,7 @@ class UserPresenceState:
last_active_ts: int last_active_ts: int
last_federation_update_ts: int last_federation_update_ts: int
last_user_sync_ts: int last_user_sync_ts: int
status_msg: Optional[str] status_msg: str | None
currently_active: bool currently_active: bool
def as_dict(self) -> JsonDict: def as_dict(self) -> JsonDict:
View File
@ -102,9 +102,7 @@ class Ratelimiter:
self.clock.looping_call(self._prune_message_counts, 15 * 1000) self.clock.looping_call(self._prune_message_counts, 15 * 1000)
def _get_key( def _get_key(self, requester: Requester | None, key: Hashable | None) -> Hashable:
self, requester: Optional[Requester], key: Optional[Hashable]
) -> Hashable:
"""Use the requester's MXID as a fallback key if no key is provided.""" """Use the requester's MXID as a fallback key if no key is provided."""
if key is None: if key is None:
if not requester: if not requester:
@ -121,13 +119,13 @@ class Ratelimiter:
async def can_do_action( async def can_do_action(
self, self,
requester: Optional[Requester], requester: Requester | None,
key: Optional[Hashable] = None, key: Hashable | None = None,
rate_hz: Optional[float] = None, rate_hz: float | None = None,
burst_count: Optional[int] = None, burst_count: int | None = None,
update: bool = True, update: bool = True,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[float] = None, _time_now_s: float | None = None,
) -> tuple[bool, float]: ) -> tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action? """Can the entity (e.g. user or IP address) perform the action?
@ -247,10 +245,10 @@ class Ratelimiter:
def record_action( def record_action(
self, self,
requester: Optional[Requester], requester: Requester | None,
key: Optional[Hashable] = None, key: Hashable | None = None,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[float] = None, _time_now_s: float | None = None,
) -> None: ) -> None:
"""Record that an action(s) took place, even if they violate the rate limit. """Record that an action(s) took place, even if they violate the rate limit.
@ -332,14 +330,14 @@ class Ratelimiter:
async def ratelimit( async def ratelimit(
self, self,
requester: Optional[Requester], requester: Requester | None,
key: Optional[Hashable] = None, key: Hashable | None = None,
rate_hz: Optional[float] = None, rate_hz: float | None = None,
burst_count: Optional[int] = None, burst_count: int | None = None,
update: bool = True, update: bool = True,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[float] = None, _time_now_s: float | None = None,
pause: Optional[float] = 0.5, pause: float | None = 0.5,
) -> None: ) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError """Checks if an action can be performed. If not, raises a LimitExceededError
@ -396,7 +394,7 @@ class RequestRatelimiter:
store: DataStore, store: DataStore,
clock: Clock, clock: Clock,
rc_message: RatelimitSettings, rc_message: RatelimitSettings,
rc_admin_redaction: Optional[RatelimitSettings], rc_admin_redaction: RatelimitSettings | None,
): ):
self.store = store self.store = store
self.clock = clock self.clock = clock
@ -412,7 +410,7 @@ class RequestRatelimiter:
# Check whether ratelimiting room admin message redaction is enabled # Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config # by the presence of rate limits in the config
if rc_admin_redaction: if rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( self.admin_redaction_ratelimiter: Ratelimiter | None = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
cfg=rc_admin_redaction, cfg=rc_admin_redaction,
View File
@ -18,7 +18,7 @@
# #
# #
from typing import Callable, Optional from typing import Callable
import attr import attr
@ -503,7 +503,7 @@ class RoomVersionCapability:
"""An object which describes the unique attributes of a room version.""" """An object which describes the unique attributes of a room version."""
identifier: str # the identifier for this capability identifier: str # the identifier for this capability
preferred_version: Optional[RoomVersion] preferred_version: RoomVersion | None
support_check_lambda: Callable[[RoomVersion], bool] support_check_lambda: Callable[[RoomVersion], bool]
View File
@ -24,7 +24,6 @@
import hmac import hmac
import urllib.parse import urllib.parse
from hashlib import sha256 from hashlib import sha256
from typing import Optional
from urllib.parse import urlencode, urljoin from urllib.parse import urlencode, urljoin
from synapse.config import ConfigError from synapse.config import ConfigError
@ -75,7 +74,7 @@ class LoginSSORedirectURIBuilder:
self._public_baseurl = hs_config.server.public_baseurl self._public_baseurl = hs_config.server.public_baseurl
def build_login_sso_redirect_uri( def build_login_sso_redirect_uri(
self, *, idp_id: Optional[str], client_redirect_url: str self, *, idp_id: str | None, client_redirect_url: str
) -> str: ) -> str:
"""Build a `/login/sso/redirect` URI for the given identity provider. """Build a `/login/sso/redirect` URI for the given identity provider.
View File
@ -36,8 +36,6 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
NoReturn, NoReturn,
Optional,
Union,
cast, cast,
) )
from wsgiref.simple_server import WSGIServer from wsgiref.simple_server import WSGIServer
@ -180,8 +178,8 @@ def start_worker_reactor(
def start_reactor( def start_reactor(
appname: str, appname: str,
soft_file_limit: int, soft_file_limit: int,
gc_thresholds: Optional[tuple[int, int, int]], gc_thresholds: tuple[int, int, int] | None,
pid_file: Optional[str], pid_file: str | None,
daemonize: bool, daemonize: bool,
print_pidfile: bool, print_pidfile: bool,
logger: logging.Logger, logger: logging.Logger,
@ -421,7 +419,7 @@ def listen_http(
root_resource: Resource, root_resource: Resource,
version_string: str, version_string: str,
max_request_body_size: int, max_request_body_size: int,
context_factory: Optional[IOpenSSLContextFactory], context_factory: IOpenSSLContextFactory | None,
reactor: ISynapseReactor = reactor, reactor: ISynapseReactor = reactor,
) -> list[Port]: ) -> list[Port]:
""" """
@ -564,9 +562,7 @@ def setup_sighup_handling() -> None:
if _already_setup_sighup_handling: if _already_setup_sighup_handling:
return return
previous_sighup_handler: Union[ previous_sighup_handler: Callable[[int, FrameType | None], Any] | int | None = None
Callable[[int, Optional[FrameType]], Any], int, None
] = None
# Set up the SIGHUP machinery. # Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"): if hasattr(signal, "SIGHUP"):
View File
@ -24,7 +24,7 @@ import logging
import os import os
import sys import sys
import tempfile import tempfile
from typing import Mapping, Optional, Sequence from typing import Mapping, Sequence
from twisted.internet import defer, task from twisted.internet import defer, task
@ -136,7 +136,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
to a temporary directory. to a temporary directory.
""" """
def __init__(self, user_id: str, directory: Optional[str] = None): def __init__(self, user_id: str, directory: str | None = None):
self.user_id = user_id self.user_id = user_id
if directory: if directory:
@ -291,7 +291,7 @@ def load_config(argv_options: list[str]) -> tuple[HomeServerConfig, argparse.Nam
def create_homeserver( def create_homeserver(
config: HomeServerConfig, config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None, reactor: ISynapseReactor | None = None,
) -> AdminCmdServer: ) -> AdminCmdServer:
""" """
Create a homeserver instance for the Synapse admin command process. Create a homeserver instance for the Synapse admin command process.
View File
@ -26,7 +26,7 @@ import os
import signal import signal
import sys import sys
from types import FrameType from types import FrameType
from typing import Any, Callable, Optional from typing import Any, Callable
from twisted.internet.main import installReactor from twisted.internet.main import installReactor
@ -172,7 +172,7 @@ def main() -> None:
# Install signal handlers to propagate signals to all our children, so that they # Install signal handlers to propagate signals to all our children, so that they
# shut down cleanly. This also inhibits our own exit, but that's good: we want to # shut down cleanly. This also inhibits our own exit, but that's good: we want to
# wait until the children have exited. # wait until the children have exited.
def handle_signal(signum: int, frame: Optional[FrameType]) -> None: def handle_signal(signum: int, frame: FrameType | None) -> None:
print( print(
f"complement_fork_starter: Caught signal {signum}. Stopping children.", f"complement_fork_starter: Caught signal {signum}. Stopping children.",
file=sys.stderr, file=sys.stderr,
View File
@ -21,7 +21,6 @@
# #
import logging import logging
import sys import sys
from typing import Optional
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -336,7 +335,7 @@ def load_config(argv_options: list[str]) -> HomeServerConfig:
def create_homeserver( def create_homeserver(
config: HomeServerConfig, config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None, reactor: ISynapseReactor | None = None,
) -> GenericWorkerServer: ) -> GenericWorkerServer:
""" """
Create a homeserver instance for the Synapse worker process. Create a homeserver instance for the Synapse worker process.
View File
@ -22,7 +22,7 @@
import logging import logging
import os import os
import sys import sys
from typing import Iterable, Optional from typing import Iterable
from twisted.internet.tcp import Port from twisted.internet.tcp import Port
from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.resource import EncodingResourceWrapper, Resource
@ -350,7 +350,7 @@ def load_or_generate_config(argv_options: list[str]) -> HomeServerConfig:
def create_homeserver( def create_homeserver(
config: HomeServerConfig, config: HomeServerConfig,
reactor: Optional[ISynapseReactor] = None, reactor: ISynapseReactor | None = None,
) -> SynapseHomeServer: ) -> SynapseHomeServer:
""" """
Create a homeserver instance for the Synapse main process. Create a homeserver instance for the Synapse main process.
View File
@ -26,7 +26,6 @@ from enum import Enum
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Iterable, Iterable,
Optional,
Pattern, Pattern,
Sequence, Sequence,
cast, cast,
@ -95,12 +94,12 @@ class ApplicationService:
token: str, token: str,
id: str, id: str,
sender: UserID, sender: UserID,
url: Optional[str] = None, url: str | None = None,
namespaces: Optional[JsonDict] = None, namespaces: JsonDict | None = None,
hs_token: Optional[str] = None, hs_token: str | None = None,
protocols: Optional[Iterable[str]] = None, protocols: Iterable[str] | None = None,
rate_limited: bool = True, rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None, ip_range_whitelist: IPSet | None = None,
supports_ephemeral: bool = False, supports_ephemeral: bool = False,
msc3202_transaction_extensions: bool = False, msc3202_transaction_extensions: bool = False,
msc4190_device_management: bool = False, msc4190_device_management: bool = False,
@ -142,7 +141,7 @@ class ApplicationService:
self.rate_limited = rate_limited self.rate_limited = rate_limited
def _check_namespaces( def _check_namespaces(
self, namespaces: Optional[JsonDict] self, namespaces: JsonDict | None
) -> dict[str, list[Namespace]]: ) -> dict[str, list[Namespace]]:
# Sanity check that it is of the form: # Sanity check that it is of the form:
# { # {
@ -179,9 +178,7 @@ class ApplicationService:
return result return result
def _matches_regex( def _matches_regex(self, namespace_key: str, test_string: str) -> Namespace | None:
self, namespace_key: str, test_string: str
) -> Optional[Namespace]:
for namespace in self.namespaces[namespace_key]: for namespace in self.namespaces[namespace_key]:
if namespace.regex.match(test_string): if namespace.regex.match(test_string):
return namespace return namespace
View File
@ -25,10 +25,8 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Iterable, Iterable,
Mapping, Mapping,
Optional,
Sequence, Sequence,
TypeVar, TypeVar,
Union,
) )
from prometheus_client import Counter from prometheus_client import Counter
@ -222,7 +220,7 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None assert service.hs_token is not None
try: try:
args: Mapping[bytes, Union[list[bytes], str]] = fields args: Mapping[bytes, list[bytes] | str] = fields
if self.config.use_appservice_legacy_authorization: if self.config.use_appservice_legacy_authorization:
args = { args = {
**fields, **fields,
@ -258,11 +256,11 @@ class ApplicationServiceApi(SimpleHttpClient):
async def get_3pe_protocol( async def get_3pe_protocol(
self, service: "ApplicationService", protocol: str self, service: "ApplicationService", protocol: str
) -> Optional[JsonDict]: ) -> JsonDict | None:
if service.url is None: if service.url is None:
return {} return {}
async def _get() -> Optional[JsonDict]: async def _get() -> JsonDict | None:
# This is required by the configuration. # This is required by the configuration.
assert service.hs_token is not None assert service.hs_token is not None
try: try:
@ -300,7 +298,7 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol) key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get) return await self.protocol_meta_cache.wrap(key, _get)
async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None: async def ping(self, service: "ApplicationService", txn_id: str | None) -> None:
# The caller should check that url is set # The caller should check that url is set
assert service.url is not None, "ping called without URL being set" assert service.url is not None, "ping called without URL being set"
@ -322,7 +320,7 @@ class ApplicationServiceApi(SimpleHttpClient):
one_time_keys_count: TransactionOneTimeKeysCount, one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys, unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates, device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None, txn_id: int | None = None,
) -> bool: ) -> bool:
""" """
Push data to an application service. Push data to an application service.
View File
@ -62,7 +62,6 @@ from typing import (
Callable, Callable,
Collection, Collection,
Iterable, Iterable,
Optional,
Sequence, Sequence,
) )
@ -123,10 +122,10 @@ class ApplicationServiceScheduler:
def enqueue_for_appservice( def enqueue_for_appservice(
self, self,
appservice: ApplicationService, appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None, events: Collection[EventBase] | None = None,
ephemeral: Optional[Collection[JsonMapping]] = None, ephemeral: Collection[JsonMapping] | None = None,
to_device_messages: Optional[Collection[JsonMapping]] = None, to_device_messages: Collection[JsonMapping] | None = None,
device_list_summary: Optional[DeviceListUpdates] = None, device_list_summary: DeviceListUpdates | None = None,
) -> None: ) -> None:
""" """
Enqueue some data to be sent off to an application service. Enqueue some data to be sent off to an application service.
@ -260,8 +259,8 @@ class _ServiceQueuer:
): ):
return return
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None one_time_keys_count: TransactionOneTimeKeysCount | None = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None unused_fallback_keys: TransactionUnusedFallbackKeys | None = None
if ( if (
self._msc3202_transaction_extensions_enabled self._msc3202_transaction_extensions_enabled
@ -369,11 +368,11 @@ class _TransactionController:
self, self,
service: ApplicationService, service: ApplicationService,
events: Sequence[EventBase], events: Sequence[EventBase],
ephemeral: Optional[list[JsonMapping]] = None, ephemeral: list[JsonMapping] | None = None,
to_device_messages: Optional[list[JsonMapping]] = None, to_device_messages: list[JsonMapping] | None = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, one_time_keys_count: TransactionOneTimeKeysCount | None = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, unused_fallback_keys: TransactionUnusedFallbackKeys | None = None,
device_list_summary: Optional[DeviceListUpdates] = None, device_list_summary: DeviceListUpdates | None = None,
) -> None: ) -> None:
""" """
Create a transaction with the given data and send to the provided Create a transaction with the given data and send to the provided
@ -504,7 +503,7 @@ class _Recoverer:
self.service = service self.service = service
self.callback = callback self.callback = callback
self.backoff_counter = 1 self.backoff_counter = 1
self.scheduled_recovery: Optional[IDelayedCall] = None self.scheduled_recovery: IDelayedCall | None = None
def recover(self) -> None: def recover(self) -> None:
delay = 2**self.backoff_counter delay = 2**self.backoff_counter
View File
@ -36,9 +36,7 @@ from typing import (
Iterable, Iterable,
Iterator, Iterator,
MutableMapping, MutableMapping,
Optional,
TypeVar, TypeVar,
Union,
) )
import attr import attr
@ -60,7 +58,7 @@ class ConfigError(Exception):
the problem lies. the problem lies.
""" """
def __init__(self, msg: str, path: Optional[StrSequence] = None): def __init__(self, msg: str, path: StrSequence | None = None):
self.msg = msg self.msg = msg
self.path = path self.path = path
@ -175,7 +173,7 @@ class Config:
) )
@staticmethod @staticmethod
def parse_size(value: Union[str, int]) -> int: def parse_size(value: str | int) -> int:
"""Interpret `value` as a number of bytes. """Interpret `value` as a number of bytes.
If an integer is provided it is treated as bytes and is unchanged. If an integer is provided it is treated as bytes and is unchanged.
@ -202,7 +200,7 @@ class Config:
raise TypeError(f"Bad byte size {value!r}") raise TypeError(f"Bad byte size {value!r}")
@staticmethod @staticmethod
def parse_duration(value: Union[str, int]) -> int: def parse_duration(value: str | int) -> int:
"""Convert a duration as a string or integer to a number of milliseconds. """Convert a duration as a string or integer to a number of milliseconds.
If an integer is provided it is treated as milliseconds and is unchanged. If an integer is provided it is treated as milliseconds and is unchanged.
@ -270,7 +268,7 @@ class Config:
return path_exists(file_path) return path_exists(file_path)
@classmethod @classmethod
def check_file(cls, file_path: Optional[str], config_name: str) -> str: def check_file(cls, file_path: str | None, config_name: str) -> str:
if file_path is None: if file_path is None:
raise ConfigError("Missing config for %s." % (config_name,)) raise ConfigError("Missing config for %s." % (config_name,))
try: try:
@ -318,7 +316,7 @@ class Config:
def read_templates( def read_templates(
self, self,
filenames: list[str], filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None, custom_template_directories: Iterable[str] | None = None,
) -> list[jinja2.Template]: ) -> list[jinja2.Template]:
"""Load a list of template files from disk using the given variables. """Load a list of template files from disk using the given variables.
@ -465,11 +463,11 @@ class RootConfig:
data_dir_path: str, data_dir_path: str,
server_name: str, server_name: str,
generate_secrets: bool = False, generate_secrets: bool = False,
report_stats: Optional[bool] = None, report_stats: bool | None = None,
open_private_ports: bool = False, open_private_ports: bool = False,
listeners: Optional[list[dict]] = None, listeners: list[dict] | None = None,
tls_certificate_path: Optional[str] = None, tls_certificate_path: str | None = None,
tls_private_key_path: Optional[str] = None, tls_private_key_path: str | None = None,
) -> str: ) -> str:
""" """
Build a default configuration file Build a default configuration file
@ -655,7 +653,7 @@ class RootConfig:
@classmethod @classmethod
def load_or_generate_config( def load_or_generate_config(
cls: type[TRootConfig], description: str, argv_options: list[str] cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]: ) -> TRootConfig | None:
"""Parse the commandline and config files """Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app. Supports generation of config files, so is used for the main homeserver app.
@ -898,7 +896,7 @@ class RootConfig:
:returns: the previous config object, which no longer has a reference to this :returns: the previous config object, which no longer has a reference to this
RootConfig. RootConfig.
""" """
existing_config: Optional[Config] = getattr(self, section_name, None) existing_config: Config | None = getattr(self, section_name, None)
if existing_config is None: if existing_config is None:
raise ValueError(f"Unknown config section '{section_name}'") raise ValueError(f"Unknown config section '{section_name}'")
logger.info("Reloading config section '%s'", section_name) logger.info("Reloading config section '%s'", section_name)
View File
@ -6,9 +6,7 @@ from typing import (
Iterator, Iterator,
Literal, Literal,
MutableMapping, MutableMapping,
Optional,
TypeVar, TypeVar,
Union,
overload, overload,
) )
@ -64,7 +62,7 @@ from synapse.config import ( # noqa: F401
from synapse.types import StrSequence from synapse.types import StrSequence
class ConfigError(Exception): class ConfigError(Exception):
def __init__(self, msg: str, path: Optional[StrSequence] = None): def __init__(self, msg: str, path: StrSequence | None = None):
self.msg = msg self.msg = msg
self.path = path self.path = path
@ -146,16 +144,16 @@ class RootConfig:
data_dir_path: str, data_dir_path: str,
server_name: str, server_name: str,
generate_secrets: bool = ..., generate_secrets: bool = ...,
report_stats: Optional[bool] = ..., report_stats: bool | None = ...,
open_private_ports: bool = ..., open_private_ports: bool = ...,
listeners: Optional[Any] = ..., listeners: Any | None = ...,
tls_certificate_path: Optional[str] = ..., tls_certificate_path: str | None = ...,
tls_private_key_path: Optional[str] = ..., tls_private_key_path: str | None = ...,
) -> str: ... ) -> str: ...
@classmethod @classmethod
def load_or_generate_config( def load_or_generate_config(
cls: type[TRootConfig], description: str, argv_options: list[str] cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]: ... ) -> TRootConfig | None: ...
@classmethod @classmethod
def load_config( def load_config(
cls: type[TRootConfig], description: str, argv_options: list[str] cls: type[TRootConfig], description: str, argv_options: list[str]
@ -183,11 +181,11 @@ class Config:
default_template_dir: str default_template_dir: str
def __init__(self, root_config: RootConfig = ...) -> None: ... def __init__(self, root_config: RootConfig = ...) -> None: ...
@staticmethod @staticmethod
def parse_size(value: Union[str, int]) -> int: ... def parse_size(value: str | int) -> int: ...
@staticmethod @staticmethod
def parse_duration(value: Union[str, int]) -> int: ... def parse_duration(value: str | int) -> int: ...
@staticmethod @staticmethod
def abspath(file_path: Optional[str]) -> str: ... def abspath(file_path: str | None) -> str: ...
@classmethod @classmethod
def path_exists(cls, file_path: str) -> bool: ... def path_exists(cls, file_path: str) -> bool: ...
@classmethod @classmethod
@ -200,7 +198,7 @@ class Config:
def read_templates( def read_templates(
self, self,
filenames: list[str], filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None, custom_template_directories: Iterable[str] | None = None,
) -> list[jinja2.Template]: ... ) -> list[jinja2.Template]: ...
def read_config_files(config_files: Iterable[str]) -> dict[str, Any]: ... def read_config_files(config_files: Iterable[str]) -> dict[str, Any]: ...
View File
@ -20,7 +20,7 @@
# #
import logging import logging
from typing import Any, Iterable, Optional from typing import Any, Iterable
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError from synapse.config._base import Config, ConfigError
@ -46,7 +46,7 @@ class ApiConfig(Config):
def _get_prejoin_state_entries( def _get_prejoin_state_entries(
self, config: JsonDict self, config: JsonDict
) -> Iterable[tuple[str, Optional[str]]]: ) -> Iterable[tuple[str, str | None]]:
"""Get the event types and state keys to include in the prejoin state.""" """Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {} room_prejoin_state_config = config.get("room_prejoin_state") or {}
View File
@ -23,7 +23,7 @@ import logging
import os import os
import re import re
import threading import threading
from typing import Any, Callable, Mapping, Optional from typing import Any, Callable, Mapping
import attr import attr
@ -53,7 +53,7 @@ class CacheProperties:
default_factor_size: float = float( default_factor_size: float = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
) )
resize_all_caches_func: Optional[Callable[[], None]] = None resize_all_caches_func: Callable[[], None] | None = None
properties = CacheProperties() properties = CacheProperties()
@ -107,7 +107,7 @@ class CacheConfig(Config):
cache_factors: dict[str, float] cache_factors: dict[str, float]
global_factor: float global_factor: float
track_memory_usage: bool track_memory_usage: bool
expiry_time_msec: Optional[int] expiry_time_msec: int | None
sync_response_cache_duration: int sync_response_cache_duration: int
@staticmethod @staticmethod
View File
@ -20,7 +20,7 @@
# #
# #
from typing import Any, Optional from typing import Any
from synapse.config.sso import SsoAttributeRequirement from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict from synapse.types import JsonDict
@ -49,7 +49,7 @@ class CasConfig(Config):
# TODO Update this to a _synapse URL. # TODO Update this to a _synapse URL.
public_baseurl = self.root.server.public_baseurl public_baseurl = self.root.server.public_baseurl
self.cas_service_url: Optional[str] = ( self.cas_service_url: str | None = (
public_baseurl + "_matrix/client/r0/login/cas/ticket" public_baseurl + "_matrix/client/r0/login/cas/ticket"
) )
View File
@ -19,7 +19,7 @@
# #
from os import path from os import path
from typing import Any, Optional from typing import Any
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.types import JsonDict from synapse.types import JsonDict
@ -33,11 +33,11 @@ class ConsentConfig(Config):
def __init__(self, *args: Any): def __init__(self, *args: Any):
super().__init__(*args) super().__init__(*args)
self.user_consent_version: Optional[str] = None self.user_consent_version: str | None = None
self.user_consent_template_dir: Optional[str] = None self.user_consent_template_dir: str | None = None
self.user_consent_server_notice_content: Optional[JsonDict] = None self.user_consent_server_notice_content: JsonDict | None = None
self.user_consent_server_notice_to_guests = False self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error: Optional[str] = None self.block_events_without_consent_error: str | None = None
self.user_consent_at_registration = False self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy" self.user_consent_policy_name = "Privacy Policy"
View File
@ -59,7 +59,7 @@ class ClientAuthMethod(enum.Enum):
PRIVATE_KEY_JWT = "private_key_jwt" PRIVATE_KEY_JWT = "private_key_jwt"
def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]: def _parse_jwks(jwks: JsonDict | None) -> Optional["JsonWebKey"]:
"""A helper function to parse a JWK dict into a JsonWebKey.""" """A helper function to parse a JWK dict into a JsonWebKey."""
if jwks is None: if jwks is None:
@ -71,7 +71,7 @@ def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]:
def _check_client_secret( def _check_client_secret(
instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] instance: "MSC3861", _attribute: attr.Attribute, _value: str | None
) -> None: ) -> None:
if instance._client_secret and instance._client_secret_path: if instance._client_secret and instance._client_secret_path:
raise ConfigError( raise ConfigError(
@ -88,7 +88,7 @@ def _check_client_secret(
def _check_admin_token( def _check_admin_token(
instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] instance: "MSC3861", _attribute: attr.Attribute, _value: str | None
) -> None: ) -> None:
if instance._admin_token and instance._admin_token_path: if instance._admin_token and instance._admin_token_path:
raise ConfigError( raise ConfigError(
@ -124,7 +124,7 @@ class MSC3861:
issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str)) issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str))
"""The URL of the OIDC Provider.""" """The URL of the OIDC Provider."""
issuer_metadata: Optional[JsonDict] = attr.ib(default=None) issuer_metadata: JsonDict | None = attr.ib(default=None)
"""The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965.""" """The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965."""
client_id: str = attr.ib( client_id: str = attr.ib(
@ -138,7 +138,7 @@ class MSC3861:
) )
"""The auth method used when calling the introspection endpoint.""" """The auth method used when calling the introspection endpoint."""
_client_secret: Optional[str] = attr.ib( _client_secret: str | None = attr.ib(
default=None, default=None,
validator=[ validator=[
attr.validators.optional(attr.validators.instance_of(str)), attr.validators.optional(attr.validators.instance_of(str)),
@ -150,7 +150,7 @@ class MSC3861:
when using any of the client_secret_* client auth methods. when using any of the client_secret_* client auth methods.
""" """
_client_secret_path: Optional[str] = attr.ib( _client_secret_path: str | None = attr.ib(
default=None, default=None,
validator=[ validator=[
attr.validators.optional(attr.validators.instance_of(str)), attr.validators.optional(attr.validators.instance_of(str)),
@ -196,19 +196,19 @@ class MSC3861:
("experimental", "msc3861", "client_auth_method"), ("experimental", "msc3861", "client_auth_method"),
) )
introspection_endpoint: Optional[str] = attr.ib( introspection_endpoint: str | None = attr.ib(
default=None, default=None,
validator=attr.validators.optional(attr.validators.instance_of(str)), validator=attr.validators.optional(attr.validators.instance_of(str)),
) )
"""The URL of the introspection endpoint used to validate access tokens.""" """The URL of the introspection endpoint used to validate access tokens."""
account_management_url: Optional[str] = attr.ib( account_management_url: str | None = attr.ib(
default=None, default=None,
validator=attr.validators.optional(attr.validators.instance_of(str)), validator=attr.validators.optional(attr.validators.instance_of(str)),
) )
"""The URL of the My Account page on the OIDC Provider as per MSC2965.""" """The URL of the My Account page on the OIDC Provider as per MSC2965."""
_admin_token: Optional[str] = attr.ib( _admin_token: str | None = attr.ib(
default=None, default=None,
validator=[ validator=[
attr.validators.optional(attr.validators.instance_of(str)), attr.validators.optional(attr.validators.instance_of(str)),
@ -220,7 +220,7 @@ class MSC3861:
This is used by the OIDC provider, to make admin calls to Synapse. This is used by the OIDC provider, to make admin calls to Synapse.
""" """
_admin_token_path: Optional[str] = attr.ib( _admin_token_path: str | None = attr.ib(
default=None, default=None,
validator=[ validator=[
attr.validators.optional(attr.validators.instance_of(str)), attr.validators.optional(attr.validators.instance_of(str)),
@ -232,7 +232,7 @@ class MSC3861:
external file. external file.
""" """
def client_secret(self) -> Optional[str]: def client_secret(self) -> str | None:
"""Returns the secret given via `client_secret` or `client_secret_path`.""" """Returns the secret given via `client_secret` or `client_secret_path`."""
if self._client_secret_path: if self._client_secret_path:
return read_secret_from_file_once( return read_secret_from_file_once(
@ -241,7 +241,7 @@ class MSC3861:
) )
return self._client_secret return self._client_secret
def admin_token(self) -> Optional[str]: def admin_token(self) -> str | None:
"""Returns the admin token given via `admin_token` or `admin_token_path`.""" """Returns the admin token given via `admin_token` or `admin_token_path`."""
if self._admin_token_path: if self._admin_token_path:
return read_secret_from_file_once( return read_secret_from_file_once(
@ -526,7 +526,7 @@ class ExperimentalConfig(Config):
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
self.msc4108_enabled = experimental.get("msc4108_enabled", False) self.msc4108_enabled = experimental.get("msc4108_enabled", False)
self.msc4108_delegation_endpoint: Optional[str] = experimental.get( self.msc4108_delegation_endpoint: str | None = experimental.get(
"msc4108_delegation_endpoint", None "msc4108_delegation_endpoint", None
) )
View File
@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import Any, Optional from typing import Any
from synapse.config._base import Config from synapse.config._base import Config
from synapse.config._util import validate_config from synapse.config._util import validate_config
@ -32,7 +32,7 @@ class FederationConfig(Config):
federation_config = config.setdefault("federation", {}) federation_config = config.setdefault("federation", {})
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist: Optional[dict] = None self.federation_domain_whitelist: dict | None = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None) federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None: if federation_domain_whitelist is not None:
View File
@ -23,7 +23,7 @@
import hashlib import hashlib
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, Iterator, Optional from typing import TYPE_CHECKING, Any, Iterator
import attr import attr
import jsonschema import jsonschema
@ -110,7 +110,7 @@ class TrustedKeyServer:
server_name: str server_name: str
# map from key id to key object, or None to disable signature verification. # map from key id to key object, or None to disable signature verification.
verify_keys: Optional[dict[str, VerifyKey]] = None verify_keys: dict[str, VerifyKey] | None = None
class KeyConfig(Config): class KeyConfig(Config):
@ -219,7 +219,7 @@ class KeyConfig(Config):
if form_secret_path: if form_secret_path:
if form_secret: if form_secret:
raise ConfigError(CONFLICTING_FORM_SECRET_OPTS_ERROR) raise ConfigError(CONFLICTING_FORM_SECRET_OPTS_ERROR)
self.form_secret: Optional[str] = read_file( self.form_secret: str | None = read_file(
form_secret_path, ("form_secret_path",) form_secret_path, ("form_secret_path",)
).strip() ).strip()
else: else:
@ -279,7 +279,7 @@ class KeyConfig(Config):
raise ConfigError("Error reading %s: %s" % (name, str(e))) raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys( def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict] self, old_signing_keys: JsonDict | None
) -> dict[str, "VerifyKeyWithExpiry"]: ) -> dict[str, "VerifyKeyWithExpiry"]:
if old_signing_keys is None: if old_signing_keys is None:
return {} return {}
@ -408,7 +408,7 @@ def _parse_key_servers(
server_name = server["server_name"] server_name = server["server_name"]
result = TrustedKeyServer(server_name=server_name) result = TrustedKeyServer(server_name=server_name)
verify_keys: Optional[dict[str, str]] = server.get("verify_keys") verify_keys: dict[str, str] | None = server.get("verify_keys")
if verify_keys is not None: if verify_keys is not None:
result.verify_keys = {} result.verify_keys = {}
for key_id, key_base64 in verify_keys.items(): for key_id, key_base64 in verify_keys.items():

View File

@ -26,7 +26,7 @@ import os
import sys import sys
import threading import threading
from string import Template from string import Template
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
import yaml import yaml
from zope.interface import implementer from zope.interface import implementer
@ -280,7 +280,7 @@ def one_time_logging_setup(*, logBeginner: LogBeginner = globalLogBeginner) -> N
def _setup_stdlib_logging( def _setup_stdlib_logging(
config: "HomeServerConfig", log_config_path: Optional[str] config: "HomeServerConfig", log_config_path: str | None
) -> None: ) -> None:
""" """
Set up Python standard library logging. Set up Python standard library logging.
@ -327,7 +327,7 @@ def _load_logging_config(log_config_path: str) -> None:
reset_logging_config() reset_logging_config()
def _reload_logging_config(log_config_path: Optional[str]) -> None: def _reload_logging_config(log_config_path: str | None) -> None:
""" """
Reload the log configuration from the file and apply it. Reload the log configuration from the file and apply it.
""" """

View File

@ -13,7 +13,7 @@
# #
# #
from typing import Any, Optional from typing import Any
from pydantic import ( from pydantic import (
AnyHttpUrl, AnyHttpUrl,
@ -36,8 +36,8 @@ from ._base import Config, ConfigError, RootConfig
class MasConfigModel(ParseModel): class MasConfigModel(ParseModel):
enabled: StrictBool = False enabled: StrictBool = False
endpoint: AnyHttpUrl = AnyHttpUrl("http://localhost:8080") endpoint: AnyHttpUrl = AnyHttpUrl("http://localhost:8080")
secret: Optional[StrictStr] = Field(default=None) secret: StrictStr | None = Field(default=None)
secret_path: Optional[FilePath] = Field(default=None) secret_path: FilePath | None = Field(default=None)
@model_validator(mode="after") @model_validator(mode="after")
def verify_secret(self) -> Self: def verify_secret(self) -> Self:
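
Editor's note: a hedged sketch, not the real MasConfigModel. The field names mirror the hunk above, but the model name, the use of plain `BaseModel`, and the validator body are assumptions made for illustration; it shows that Pydantic v2 accepts PEP 604 optionals (`StrictStr | None`) directly in field annotations.

from pydantic import BaseModel, Field, StrictStr, model_validator
from typing_extensions import Self

class SecretModel(BaseModel):
    secret: StrictStr | None = Field(default=None)
    secret_path: StrictStr | None = Field(default=None)  # FilePath in the real config

    @model_validator(mode="after")
    def verify_secret(self) -> Self:
        # Assumed rule for this sketch: at most one of the two options may be set.
        if self.secret is not None and self.secret_path is not None:
            raise ValueError("secret and secret_path are mutually exclusive")
        return self

SecretModel(secret="s3kr1t")  # accepted
SecretModel()                 # also accepted: both fields default to None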

View File

@ -15,7 +15,7 @@
# #
# #
from typing import Any, Optional from typing import Any
from pydantic import Field, StrictStr, ValidationError, model_validator from pydantic import Field, StrictStr, ValidationError, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -29,7 +29,7 @@ from ._base import Config, ConfigError
class TransportConfigModel(ParseModel): class TransportConfigModel(ParseModel):
type: StrictStr type: StrictStr
livekit_service_url: Optional[StrictStr] = Field(default=None) livekit_service_url: StrictStr | None = Field(default=None)
"""An optional livekit service URL. Only required if type is "livekit".""" """An optional livekit service URL. Only required if type is "livekit"."""
@model_validator(mode="after") @model_validator(mode="after")

View File

@ -20,7 +20,7 @@
# #
# #
from typing import Any, Optional from typing import Any
import attr import attr
@ -75,7 +75,7 @@ class MetricsConfig(Config):
) )
def generate_config_section( def generate_config_section(
self, report_stats: Optional[bool] = None, **kwargs: Any self, report_stats: bool | None = None, **kwargs: Any
) -> str: ) -> str:
if report_stats is not None: if report_stats is not None:
res = "report_stats: %s\n" % ("true" if report_stats else "false") res = "report_stats: %s\n" % ("true" if report_stats else "false")

View File

@ -21,7 +21,7 @@
import importlib.resources as importlib_resources import importlib.resources as importlib_resources
import json import json
import re import re
from typing import Any, Iterable, Optional, Pattern from typing import Any, Iterable, Pattern
from urllib import parse as urlparse from urllib import parse as urlparse
import attr import attr
@ -39,7 +39,7 @@ class OEmbedEndpointConfig:
# The patterns to match. # The patterns to match.
url_patterns: list[Pattern[str]] url_patterns: list[Pattern[str]]
# The supported formats. # The supported formats.
formats: Optional[list[str]] formats: list[str] | None
class OembedConfig(Config): class OembedConfig(Config):

View File

@ -21,7 +21,7 @@
# #
from collections import Counter from collections import Counter
from typing import Any, Collection, Iterable, Mapping, Optional from typing import Any, Collection, Iterable, Mapping
import attr import attr
@ -276,7 +276,7 @@ def _parse_oidc_config_dict(
) from e ) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None client_secret_jwt_key: OidcProviderClientSecretJwtKey | None = None
if client_secret_jwt_key_config is not None: if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file") keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile: if keyfile:
@ -384,10 +384,10 @@ class OidcProviderConfig:
idp_name: str idp_name: str
# Optional MXC URI for icon for this IdP. # Optional MXC URI for icon for this IdP.
idp_icon: Optional[str] idp_icon: str | None
# Optional brand identifier for this IdP. # Optional brand identifier for this IdP.
idp_brand: Optional[str] idp_brand: str | None
# whether the OIDC discovery mechanism is used to discover endpoints # whether the OIDC discovery mechanism is used to discover endpoints
discover: bool discover: bool
@ -401,11 +401,11 @@ class OidcProviderConfig:
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret. # a secret.
client_secret: Optional[str] client_secret: str | None
# key to use to construct a JWT to use as a client secret. May be `None` if # key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set. # `client_secret` is set.
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] client_secret_jwt_key: OidcProviderClientSecretJwtKey | None
# auth method to use when exchanging the token. # auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and # Valid values are 'client_secret_basic', 'client_secret_post' and
@ -416,7 +416,7 @@ class OidcProviderConfig:
# Valid values are 'auto', 'always', and 'never'. # Valid values are 'auto', 'always', and 'never'.
pkce_method: str pkce_method: str
id_token_signing_alg_values_supported: Optional[list[str]] id_token_signing_alg_values_supported: list[str] | None
""" """
List of the JWS signing algorithms (`alg` values) that are supported for signing the List of the JWS signing algorithms (`alg` values) that are supported for signing the
`id_token`. `id_token`.
@ -448,18 +448,18 @@ class OidcProviderConfig:
scopes: Collection[str] scopes: Collection[str]
# the oauth2 authorization endpoint. Required if discovery is disabled. # the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint: Optional[str] authorization_endpoint: str | None
# the oauth2 token endpoint. Required if discovery is disabled. # the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint: Optional[str] token_endpoint: str | None
# the OIDC userinfo endpoint. Required if discovery is disabled and the # the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested. # "openid" scope is not requested.
userinfo_endpoint: Optional[str] userinfo_endpoint: str | None
# URI where to fetch the JWKS. Required if discovery is disabled and the # URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used. # "openid" scope is used.
jwks_uri: Optional[str] jwks_uri: str | None
# Whether Synapse should react to backchannel logouts # Whether Synapse should react to backchannel logouts
backchannel_logout_enabled: bool backchannel_logout_enabled: bool
@ -474,7 +474,7 @@ class OidcProviderConfig:
# values are: "auto" or "userinfo_endpoint". # values are: "auto" or "userinfo_endpoint".
user_profile_method: str user_profile_method: str
redirect_uri: Optional[str] redirect_uri: str | None
""" """
An optional replacement for Synapse's hardcoded `redirect_uri` URL An optional replacement for Synapse's hardcoded `redirect_uri` URL
(`<public_baseurl>/_synapse/client/oidc/callback`). This can be used to send (`<public_baseurl>/_synapse/client/oidc/callback`). This can be used to send

View File

@ -19,7 +19,7 @@
# #
# #
from typing import Any, Optional, cast from typing import Any, cast
import attr import attr
@ -39,7 +39,7 @@ class RatelimitSettings:
cls, cls,
config: dict[str, Any], config: dict[str, Any],
key: str, key: str,
defaults: Optional[dict[str, float]] = None, defaults: dict[str, float] | None = None,
) -> "RatelimitSettings": ) -> "RatelimitSettings":
"""Parse config[key] as a new-style rate limiter config. """Parse config[key] as a new-style rate limiter config.

View File

@ -20,7 +20,7 @@
# #
# #
import argparse import argparse
from typing import Any, Optional from typing import Any
from synapse.api.constants import RoomCreationPreset from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError, read_file from synapse.config._base import Config, ConfigError, read_file
@ -181,7 +181,7 @@ class RegistrationConfig(Config):
refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime = self.parse_duration(
refreshable_access_token_lifetime refreshable_access_token_lifetime
) )
self.refreshable_access_token_lifetime: Optional[int] = ( self.refreshable_access_token_lifetime: int | None = (
refreshable_access_token_lifetime refreshable_access_token_lifetime
) )
@ -226,7 +226,7 @@ class RegistrationConfig(Config):
refresh_token_lifetime = config.get("refresh_token_lifetime") refresh_token_lifetime = config.get("refresh_token_lifetime")
if refresh_token_lifetime is not None: if refresh_token_lifetime is not None:
refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime self.refresh_token_lifetime: int | None = refresh_token_lifetime
if ( if (
self.session_lifetime is not None self.session_lifetime is not None

View File

@ -20,7 +20,7 @@
# #
import logging import logging
from typing import Any, Optional from typing import Any
import attr import attr
@ -35,8 +35,8 @@ class RetentionPurgeJob:
"""Object describing the configuration of the manhole""" """Object describing the configuration of the manhole"""
interval: int interval: int
shortest_max_lifetime: Optional[int] shortest_max_lifetime: int | None
longest_max_lifetime: Optional[int] longest_max_lifetime: int | None
class RetentionConfig(Config): class RetentionConfig(Config):

View File

@ -25,7 +25,7 @@ import logging
import os.path import os.path
import urllib.parse import urllib.parse
from textwrap import indent from textwrap import indent
from typing import Any, Iterable, Optional, TypedDict, Union from typing import Any, Iterable, TypedDict
from urllib.request import getproxies_environment from urllib.request import getproxies_environment
import attr import attr
@ -95,9 +95,9 @@ def _6to4(network: IPNetwork) -> IPNetwork:
def generate_ip_set( def generate_ip_set(
ip_addresses: Optional[Iterable[str]], ip_addresses: Iterable[str] | None,
extra_addresses: Optional[Iterable[str]] = None, extra_addresses: Iterable[str] | None = None,
config_path: Optional[StrSequence] = None, config_path: StrSequence | None = None,
) -> IPSet: ) -> IPSet:
""" """
Generate an IPSet from a list of IP addresses or CIDRs. Generate an IPSet from a list of IP addresses or CIDRs.
@ -230,8 +230,8 @@ class HttpListenerConfig:
x_forwarded: bool = False x_forwarded: bool = False
resources: list[HttpResourceConfig] = attr.Factory(list) resources: list[HttpResourceConfig] = attr.Factory(list)
additional_resources: dict[str, dict] = attr.Factory(dict) additional_resources: dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None tag: str | None = None
request_id_header: Optional[str] = None request_id_header: str | None = None
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -244,7 +244,7 @@ class TCPListenerConfig:
tls: bool = False tls: bool = False
# http_options is only populated if type=http # http_options is only populated if type=http
http_options: Optional[HttpListenerConfig] = None http_options: HttpListenerConfig | None = None
def get_site_tag(self) -> str: def get_site_tag(self) -> str:
"""Retrieves http_options.tag if it exists, otherwise the port number.""" """Retrieves http_options.tag if it exists, otherwise the port number."""
@ -269,7 +269,7 @@ class UnixListenerConfig:
type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
# http_options is only populated if type=http # http_options is only populated if type=http
http_options: Optional[HttpListenerConfig] = None http_options: HttpListenerConfig | None = None
def get_site_tag(self) -> str: def get_site_tag(self) -> str:
return "unix" return "unix"
@ -279,7 +279,7 @@ class UnixListenerConfig:
return False return False
ListenerConfig = Union[TCPListenerConfig, UnixListenerConfig] ListenerConfig = TCPListenerConfig | UnixListenerConfig
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -288,14 +288,14 @@ class ManholeConfig:
username: str = attr.ib(validator=attr.validators.instance_of(str)) username: str = attr.ib(validator=attr.validators.instance_of(str))
password: str = attr.ib(validator=attr.validators.instance_of(str)) password: str = attr.ib(validator=attr.validators.instance_of(str))
priv_key: Optional[Key] priv_key: Key | None
pub_key: Optional[Key] pub_key: Key | None
@attr.s(frozen=True) @attr.s(frozen=True)
class LimitRemoteRoomsConfig: class LimitRemoteRoomsConfig:
enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
complexity: Union[float, int] = attr.ib( complexity: float | int = attr.ib(
validator=attr.validators.instance_of((float, int)), # noqa validator=attr.validators.instance_of((float, int)), # noqa
default=1.0, default=1.0,
) )
@ -313,11 +313,11 @@ class ProxyConfigDictionary(TypedDict):
Dictionary of proxy settings suitable for interacting with `urllib.request` API's Dictionary of proxy settings suitable for interacting with `urllib.request` API's
""" """
http: Optional[str] http: str | None
""" """
Proxy server to use for HTTP requests. Proxy server to use for HTTP requests.
""" """
https: Optional[str] https: str | None
""" """
Proxy server to use for HTTPS requests. Proxy server to use for HTTPS requests.
""" """
@ -336,15 +336,15 @@ class ProxyConfig:
Synapse configuration for HTTP proxy settings. Synapse configuration for HTTP proxy settings.
""" """
http_proxy: Optional[str] http_proxy: str | None
""" """
Proxy server to use for HTTP requests. Proxy server to use for HTTP requests.
""" """
https_proxy: Optional[str] https_proxy: str | None
""" """
Proxy server to use for HTTPS requests. Proxy server to use for HTTPS requests.
""" """
no_proxy_hosts: Optional[list[str]] no_proxy_hosts: list[str] | None
""" """
List of hosts, IP addresses, or IP ranges in CIDR format which should not use the List of hosts, IP addresses, or IP ranges in CIDR format which should not use the
proxy. Synapse will directly connect to these hosts. proxy. Synapse will directly connect to these hosts.
@ -607,7 +607,7 @@ class ServerConfig(Config):
# before redacting them. # before redacting them.
redaction_retention_period = config.get("redaction_retention_period", "7d") redaction_retention_period = config.get("redaction_retention_period", "7d")
if redaction_retention_period is not None: if redaction_retention_period is not None:
self.redaction_retention_period: Optional[int] = self.parse_duration( self.redaction_retention_period: int | None = self.parse_duration(
redaction_retention_period redaction_retention_period
) )
else: else:
@ -618,7 +618,7 @@ class ServerConfig(Config):
"forgotten_room_retention_period", None "forgotten_room_retention_period", None
) )
if forgotten_room_retention_period is not None: if forgotten_room_retention_period is not None:
self.forgotten_room_retention_period: Optional[int] = self.parse_duration( self.forgotten_room_retention_period: int | None = self.parse_duration(
forgotten_room_retention_period forgotten_room_retention_period
) )
else: else:
@ -627,7 +627,7 @@ class ServerConfig(Config):
# How long to keep entries in the `users_ips` table. # How long to keep entries in the `users_ips` table.
user_ips_max_age = config.get("user_ips_max_age", "28d") user_ips_max_age = config.get("user_ips_max_age", "28d")
if user_ips_max_age is not None: if user_ips_max_age is not None:
self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age) self.user_ips_max_age: int | None = self.parse_duration(user_ips_max_age)
else: else:
self.user_ips_max_age = None self.user_ips_max_age = None
@ -864,11 +864,11 @@ class ServerConfig(Config):
) )
# Whitelist of domain names that given next_link parameters must have # Whitelist of domain names that given next_link parameters must have
next_link_domain_whitelist: Optional[list[str]] = config.get( next_link_domain_whitelist: list[str] | None = config.get(
"next_link_domain_whitelist" "next_link_domain_whitelist"
) )
self.next_link_domain_whitelist: Optional[set[str]] = None self.next_link_domain_whitelist: set[str] | None = None
if next_link_domain_whitelist is not None: if next_link_domain_whitelist is not None:
if not isinstance(next_link_domain_whitelist, list): if not isinstance(next_link_domain_whitelist, list):
raise ConfigError("'next_link_domain_whitelist' must be a list") raise ConfigError("'next_link_domain_whitelist' must be a list")
@ -880,7 +880,7 @@ class ServerConfig(Config):
if not isinstance(templates_config, dict): if not isinstance(templates_config, dict):
raise ConfigError("The 'templates' section must be a dictionary") raise ConfigError("The 'templates' section must be a dictionary")
self.custom_template_directory: Optional[str] = templates_config.get( self.custom_template_directory: str | None = templates_config.get(
"custom_template_directory" "custom_template_directory"
) )
if self.custom_template_directory is not None and not isinstance( if self.custom_template_directory is not None and not isinstance(
@ -896,12 +896,12 @@ class ServerConfig(Config):
config.get("exclude_rooms_from_sync") or [] config.get("exclude_rooms_from_sync") or []
) )
delete_stale_devices_after: Optional[str] = ( delete_stale_devices_after: str | None = (
config.get("delete_stale_devices_after") or None config.get("delete_stale_devices_after") or None
) )
if delete_stale_devices_after is not None: if delete_stale_devices_after is not None:
self.delete_stale_devices_after: Optional[int] = self.parse_duration( self.delete_stale_devices_after: int | None = self.parse_duration(
delete_stale_devices_after delete_stale_devices_after
) )
else: else:
@ -910,7 +910,7 @@ class ServerConfig(Config):
# The maximum allowed delay duration for delayed events (MSC4140). # The maximum allowed delay duration for delayed events (MSC4140).
max_event_delay_duration = config.get("max_event_delay_duration") max_event_delay_duration = config.get("max_event_delay_duration")
if max_event_delay_duration is not None: if max_event_delay_duration is not None:
self.max_event_delay_ms: Optional[int] = self.parse_duration( self.max_event_delay_ms: int | None = self.parse_duration(
max_event_delay_duration max_event_delay_duration
) )
if self.max_event_delay_ms <= 0: if self.max_event_delay_ms <= 0:
@ -927,7 +927,7 @@ class ServerConfig(Config):
data_dir_path: str, data_dir_path: str,
server_name: str, server_name: str,
open_private_ports: bool, open_private_ports: bool,
listeners: Optional[list[dict]], listeners: list[dict] | None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
_, bind_port = parse_and_validate_server_name(server_name) _, bind_port = parse_and_validate_server_name(server_name)
@ -1028,7 +1028,7 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.", help="Turn on the twisted telnet manhole service on the given port.",
) )
def read_gc_intervals(self, durations: Any) -> Optional[tuple[float, float, float]]: def read_gc_intervals(self, durations: Any) -> tuple[float, float, float] | None:
"""Reads the three durations for the GC min interval option, returning seconds.""" """Reads the three durations for the GC min interval option, returning seconds."""
if durations is None: if durations is None:
return None return None
@ -1066,8 +1066,8 @@ def is_threepid_reserved(
def read_gc_thresholds( def read_gc_thresholds(
thresholds: Optional[list[Any]], thresholds: list[Any] | None,
) -> Optional[tuple[int, int, int]]: ) -> tuple[int, int, int] | None:
"""Reads the three integer thresholds for garbage collection. Ensures that """Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied. the thresholds are integers if thresholds are supplied.
""" """

View File

@ -18,7 +18,7 @@
# #
# #
from typing import Any, Optional from typing import Any
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
@ -58,12 +58,12 @@ class ServerNoticesConfig(Config):
def __init__(self, *args: Any): def __init__(self, *args: Any):
super().__init__(*args) super().__init__(*args)
self.server_notices_mxid: Optional[str] = None self.server_notices_mxid: str | None = None
self.server_notices_mxid_display_name: Optional[str] = None self.server_notices_mxid_display_name: str | None = None
self.server_notices_mxid_avatar_url: Optional[str] = None self.server_notices_mxid_avatar_url: str | None = None
self.server_notices_room_name: Optional[str] = None self.server_notices_room_name: str | None = None
self.server_notices_room_avatar_url: Optional[str] = None self.server_notices_room_avatar_url: str | None = None
self.server_notices_room_topic: Optional[str] = None self.server_notices_room_topic: str | None = None
self.server_notices_auto_join: bool = False self.server_notices_auto_join: bool = False
def read_config(self, config: JsonDict, **kwargs: Any) -> None: def read_config(self, config: JsonDict, **kwargs: Any) -> None:

View File

@ -19,7 +19,7 @@
# #
# #
import logging import logging
from typing import Any, Optional from typing import Any
import attr import attr
@ -44,8 +44,8 @@ class SsoAttributeRequirement:
attribute: str attribute: str
# If neither `value` nor `one_of` is given, the attribute must simply exist. # If neither `value` nor `one_of` is given, the attribute must simply exist.
value: Optional[str] = None value: str | None = None
one_of: Optional[list[str]] = None one_of: list[str] | None = None
JSON_SCHEMA = { JSON_SCHEMA = {
"type": "object", "type": "object",

View File

@ -20,7 +20,7 @@
# #
import logging import logging
from typing import Any, Optional, Pattern from typing import Any, Pattern
from matrix_common.regex import glob_to_regex from matrix_common.regex import glob_to_regex
@ -135,8 +135,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use" "use_insecure_ssl_client_just_for_testing_do_not_use"
) )
self.tls_certificate: Optional[crypto.X509] = None self.tls_certificate: crypto.X509 | None = None
self.tls_private_key: Optional[crypto.PKey] = None self.tls_private_key: crypto.PKey | None = None
def read_certificate_from_disk(self) -> None: def read_certificate_from_disk(self) -> None:
""" """
@ -147,8 +147,8 @@ class TlsConfig(Config):
def generate_config_section( def generate_config_section(
self, self,
tls_certificate_path: Optional[str], tls_certificate_path: str | None,
tls_private_key_path: Optional[str], tls_private_key_path: str | None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""If the TLS paths are not specified the default will be certs in the """If the TLS paths are not specified the default will be certs in the

View File

@ -12,7 +12,7 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>. # <https://www.gnu.org/licenses/agpl-3.0.html>.
# #
from typing import Any, Optional from typing import Any
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.types import JsonDict from synapse.types import JsonDict
@ -26,9 +26,7 @@ class UserTypesConfig(Config):
def read_config(self, config: JsonDict, **kwargs: Any) -> None: def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_types: JsonDict = config.get("user_types", {}) user_types: JsonDict = config.get("user_types", {})
self.default_user_type: Optional[str] = user_types.get( self.default_user_type: str | None = user_types.get("default_user_type", None)
"default_user_type", None
)
self.extra_user_types: list[str] = user_types.get("extra_user_types", []) self.extra_user_types: list[str] = user_types.get("extra_user_types", [])
all_user_types: list[str] = [] all_user_types: list[str] = []

View File

@ -22,7 +22,7 @@
import argparse import argparse
import logging import logging
from typing import Any, Optional, Union from typing import Any
import attr import attr
from pydantic import ( from pydantic import (
@ -79,7 +79,7 @@ MAIN_PROCESS_INSTANCE_MAP_NAME = "main"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _instance_to_list_converter(obj: Union[str, list[str]]) -> list[str]: def _instance_to_list_converter(obj: str | list[str]) -> list[str]:
"""Helper for allowing parsing a string or list of strings to a config """Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings. option expecting a list of strings.
""" """
@ -119,7 +119,7 @@ class InstanceUnixLocationConfig(ParseModel):
return f"{self.path}" return f"{self.path}"
InstanceLocationConfig = Union[InstanceTcpLocationConfig, InstanceUnixLocationConfig] InstanceLocationConfig = InstanceTcpLocationConfig | InstanceUnixLocationConfig
@attr.s @attr.s
@ -190,7 +190,7 @@ class OutboundFederationRestrictedTo:
locations: list of instance locations to connect to proxy via. locations: list of instance locations to connect to proxy via.
""" """
instances: Optional[list[str]] instances: list[str] | None
locations: list[InstanceLocationConfig] = attr.Factory(list) locations: list[InstanceLocationConfig] = attr.Factory(list)
def __contains__(self, instance: str) -> bool: def __contains__(self, instance: str) -> bool:
@ -246,7 +246,7 @@ class WorkerConfig(Config):
if worker_replication_secret_path: if worker_replication_secret_path:
if worker_replication_secret: if worker_replication_secret:
raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR) raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR)
self.worker_replication_secret: Optional[str] = read_file( self.worker_replication_secret: str | None = read_file(
worker_replication_secret_path, ("worker_replication_secret_path",) worker_replication_secret_path, ("worker_replication_secret_path",)
).strip() ).strip()
else: else:
@ -341,7 +341,7 @@ class WorkerConfig(Config):
% MAIN_PROCESS_INSTANCE_MAP_NAME % MAIN_PROCESS_INSTANCE_MAP_NAME
) )
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently # type-ignore: the expression `A | B` is not a `type[A | B]` currently
self.instance_map: dict[str, InstanceLocationConfig] = ( self.instance_map: dict[str, InstanceLocationConfig] = (
parse_and_validate_mapping( parse_and_validate_mapping(
instance_map, instance_map,
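
Editor's note: a minimal, self-contained sketch (the classes and alias are made up) of the limitation the type-ignore comment above refers to: a PEP 604 union is a `types.UnionType` instance rather than a class, so it cannot be handed to code that expects a `type[...]` value, even though `isinstance()` accepts unions on Python 3.10+.

import types

class TcpLocation: ...
class UnixLocation: ...

Location = TcpLocation | UnixLocation          # runtime value is a types.UnionType

print(isinstance(Location, type))              # False: the union is not a class object
print(isinstance(Location, types.UnionType))   # True
print(isinstance(TcpLocation(), Location))     # True: isinstance understands unions (3.10+)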

View File

@ -21,7 +21,7 @@
import abc import abc
import logging import logging
from typing import TYPE_CHECKING, Callable, Iterable, Optional from typing import TYPE_CHECKING, Callable, Iterable
import attr import attr
from signedjson.key import ( from signedjson.key import (
@ -150,7 +150,7 @@ class Keyring:
""" """
def __init__( def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None self, hs: "HomeServer", key_fetchers: "Iterable[KeyFetcher] | None" = None
): ):
self.server_name = hs.hostname self.server_name = hs.hostname
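
Editor's note: a hedged illustration (the `Widget` class is invented) of why annotations containing quoted forward references get special handling in this commit. The whole annotation can be quoted, as in the Keyring signature above, or the `Optional[...]` spelling can be kept, which is presumably why a few `Optional["..."]` forms survive elsewhere in the diff; but writing `"Forward" | None` would be evaluated at definition time and fail, because `str` does not implement the `|` operator.

from typing import Optional

class Widget: ...

def quoted_whole(w: "Widget | None" = None) -> None: ...      # fine: the string annotation is stored, not evaluated
def optional_kept(w: Optional["Widget"] = None) -> None: ...  # fine: Optional[...] accepts a string forward reference
# def broken(w: "Widget" | None = None) -> None: ...          # TypeError: unsupported operand type(s) for |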

View File

@ -160,7 +160,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
async def check_state_independent_auth_rules( async def check_state_independent_auth_rules(
store: _EventSourceStore, store: _EventSourceStore,
event: "EventBase", event: "EventBase",
batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, batched_auth_events: Mapping[str, "EventBase"] | None = None,
) -> None: ) -> None:
"""Check that an event complies with auth rules that are independent of room state """Check that an event complies with auth rules that are independent of room state
@ -788,7 +788,7 @@ def _check_joined_room(
def get_send_level( def get_send_level(
etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"] etype: str, state_key: str | None, power_levels_event: Optional["EventBase"]
) -> int: ) -> int:
"""Get the power level required to send an event of a given type """Get the power level required to send an event of a given type
@ -989,7 +989,7 @@ def _check_power_levels(
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels: # Check other levels:
levels_to_check: list[tuple[str, Optional[str]]] = [ levels_to_check: list[tuple[str, str | None]] = [
("users_default", None), ("users_default", None),
("events_default", None), ("events_default", None),
("state_default", None), ("state_default", None),
@ -1027,12 +1027,12 @@ def _check_power_levels(
new_loc = new_loc.get(dir, {}) new_loc = new_loc.get(dir, {})
if level_to_check in old_loc: if level_to_check in old_loc:
old_level: Optional[int] = int(old_loc[level_to_check]) old_level: int | None = int(old_loc[level_to_check])
else: else:
old_level = None old_level = None
if level_to_check in new_loc: if level_to_check in new_loc:
new_level: Optional[int] = int(new_loc[level_to_check]) new_level: int | None = int(new_loc[level_to_check])
else: else:
new_level = None new_level = None

View File

@ -28,7 +28,6 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
Literal, Literal,
Optional,
TypeVar, TypeVar,
Union, Union,
overload, overload,
@ -90,21 +89,21 @@ class DictProperty(Generic[T]):
def __get__( def __get__(
self, self,
instance: Literal[None], instance: Literal[None],
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> "DictProperty": ... ) -> "DictProperty": ...
@overload @overload
def __get__( def __get__(
self, self,
instance: _DictPropertyInstance, instance: _DictPropertyInstance,
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> T: ... ) -> T: ...
def __get__( def __get__(
self, self,
instance: Optional[_DictPropertyInstance], instance: _DictPropertyInstance | None,
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> Union[T, "DictProperty"]: ) -> T | "DictProperty":
# if the property is accessed as a class property rather than an instance # if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value # property, return the property itself rather than the value
if instance is None: if instance is None:
@ -156,21 +155,21 @@ class DefaultDictProperty(DictProperty, Generic[T]):
def __get__( def __get__(
self, self,
instance: Literal[None], instance: Literal[None],
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> "DefaultDictProperty": ... ) -> "DefaultDictProperty": ...
@overload @overload
def __get__( def __get__(
self, self,
instance: _DictPropertyInstance, instance: _DictPropertyInstance,
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> T: ... ) -> T: ...
def __get__( def __get__(
self, self,
instance: Optional[_DictPropertyInstance], instance: _DictPropertyInstance | None,
owner: Optional[type[_DictPropertyInstance]] = None, owner: type[_DictPropertyInstance] | None = None,
) -> Union[T, "DefaultDictProperty"]: ) -> T | "DefaultDictProperty":
if instance is None: if instance is None:
return self return self
assert isinstance(instance, EventBase) assert isinstance(instance, EventBase)
@ -191,7 +190,7 @@ class EventBase(metaclass=abc.ABCMeta):
signatures: dict[str, dict[str, str]], signatures: dict[str, dict[str, str]],
unsigned: JsonDict, unsigned: JsonDict,
internal_metadata_dict: JsonDict, internal_metadata_dict: JsonDict,
rejected_reason: Optional[str], rejected_reason: str | None,
): ):
assert room_version.event_format == self.format_version assert room_version.event_format == self.format_version
@ -209,7 +208,7 @@ class EventBase(metaclass=abc.ABCMeta):
hashes: DictProperty[dict[str, str]] = DictProperty("hashes") hashes: DictProperty[dict[str, str]] = DictProperty("hashes")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
sender: DictProperty[str] = DictProperty("sender") sender: DictProperty[str] = DictProperty("sender")
# TODO state_key should be Optional[str]. This is generally asserted in Synapse # TODO state_key should be str | None. This is generally asserted in Synapse
# by calling is_state() first (which ensures it is not None), but it is hard (not possible?) # by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
# to properly annotate that calling is_state() asserts that state_key exists # to properly annotate that calling is_state() asserts that state_key exists
# and is non-None. It would be better to replace such direct references with # and is non-None. It would be better to replace such direct references with
@ -231,7 +230,7 @@ class EventBase(metaclass=abc.ABCMeta):
return self.content["membership"] return self.content["membership"]
@property @property
def redacts(self) -> Optional[str]: def redacts(self) -> str | None:
"""MSC2176 moved the redacts field into the content.""" """MSC2176 moved the redacts field into the content."""
if self.room_version.updated_redaction_rules: if self.room_version.updated_redaction_rules:
return self.content.get("redacts") return self.content.get("redacts")
@ -240,7 +239,7 @@ class EventBase(metaclass=abc.ABCMeta):
def is_state(self) -> bool: def is_state(self) -> bool:
return self.get_state_key() is not None return self.get_state_key() is not None
def get_state_key(self) -> Optional[str]: def get_state_key(self) -> str | None:
"""Get the state key of this event, or None if it's not a state event""" """Get the state key of this event, or None if it's not a state event"""
return self._dict.get("state_key") return self._dict.get("state_key")
@ -250,13 +249,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d return d
def get(self, key: str, default: Optional[Any] = None) -> Any: def get(self, key: str, default: Any | None = None) -> Any:
return self._dict.get(key, default) return self._dict.get(key, default)
def get_internal_metadata_dict(self) -> JsonDict: def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict() return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict: def get_pdu_json(self, time_now: int | None = None) -> JsonDict:
pdu_json = self.get_dict() pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]: if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@ -283,13 +282,13 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json return template_json
def __getitem__(self, field: str) -> Optional[Any]: def __getitem__(self, field: str) -> Any | None:
return self._dict[field] return self._dict[field]
def __contains__(self, field: str) -> bool: def __contains__(self, field: str) -> bool:
return field in self._dict return field in self._dict
def items(self) -> list[tuple[str, Optional[Any]]]: def items(self) -> list[tuple[str, Any | None]]:
return list(self._dict.items()) return list(self._dict.items())
def keys(self) -> Iterable[str]: def keys(self) -> Iterable[str]:
@ -348,8 +347,8 @@ class FrozenEvent(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None, internal_metadata_dict: JsonDict | None = None,
rejected_reason: Optional[str] = None, rejected_reason: str | None = None,
): ):
internal_metadata_dict = internal_metadata_dict or {} internal_metadata_dict = internal_metadata_dict or {}
@ -400,8 +399,8 @@ class FrozenEventV2(EventBase):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None, internal_metadata_dict: JsonDict | None = None,
rejected_reason: Optional[str] = None, rejected_reason: str | None = None,
): ):
internal_metadata_dict = internal_metadata_dict or {} internal_metadata_dict = internal_metadata_dict or {}
@ -427,7 +426,7 @@ class FrozenEventV2(EventBase):
else: else:
frozen_dict = event_dict frozen_dict = event_dict
self._event_id: Optional[str] = None self._event_id: str | None = None
super().__init__( super().__init__(
frozen_dict, frozen_dict,
@ -502,8 +501,8 @@ class FrozenEventV4(FrozenEventV3):
self, self,
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion, room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None, internal_metadata_dict: JsonDict | None = None,
rejected_reason: Optional[str] = None, rejected_reason: str | None = None,
): ):
super().__init__( super().__init__(
event_dict=event_dict, event_dict=event_dict,
@ -511,7 +510,7 @@ class FrozenEventV4(FrozenEventV3):
internal_metadata_dict=internal_metadata_dict, internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason, rejected_reason=rejected_reason,
) )
self._room_id: Optional[str] = None self._room_id: str | None = None
@property @property
def room_id(self) -> str: def room_id(self) -> str:
@ -554,7 +553,7 @@ class FrozenEventV4(FrozenEventV3):
def _event_type_from_format_version( def _event_type_from_format_version(
format_version: int, format_version: int,
) -> type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]: ) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]:
"""Returns the python type to use to construct an Event object for the """Returns the python type to use to construct an Event object for the
given event format version. given event format version.
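
Editor's note: a small sketch (invented classes) of the distinction in the return annotation above: `type[A | B]` is an ordinary annotation meaning "a class object that is either A or B", and the function returns one of the classes themselves, not the union object.

class EventV1: ...
class EventV2: ...

def event_class_for(version: int) -> type[EventV1 | EventV2]:
    # Returns a class, which the caller can then instantiate.
    return EventV1 if version == 1 else EventV2

cls = event_class_for(2)
event = cls()  # instantiate whichever class was selected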
@ -580,8 +579,8 @@ def _event_type_from_format_version(
def make_event_from_dict( def make_event_from_dict(
event_dict: JsonDict, event_dict: JsonDict,
room_version: RoomVersion = RoomVersions.V1, room_version: RoomVersion = RoomVersions.V1,
internal_metadata_dict: Optional[JsonDict] = None, internal_metadata_dict: JsonDict | None = None,
rejected_reason: Optional[str] = None, rejected_reason: str | None = None,
) -> EventBase: ) -> EventBase:
"""Construct an EventBase from the given event dict""" """Construct an EventBase from the given event dict"""
event_type = _event_type_from_format_version(room_version.event_format) event_type = _event_type_from_format_version(room_version.event_format)
@ -598,10 +597,10 @@ class _EventRelation:
rel_type: str rel_type: str
# The aggregation key. Will be None if the rel_type is not m.annotation or is # The aggregation key. Will be None if the rel_type is not m.annotation or is
# not a string. # not a string.
aggregation_key: Optional[str] aggregation_key: str | None
def relation_from_event(event: EventBase) -> Optional[_EventRelation]: def relation_from_event(event: EventBase) -> _EventRelation | None:
""" """
Attempt to parse relation information from an event. Attempt to parse relation information from an event.

View File

@ -19,7 +19,7 @@
# #
# #
import logging import logging
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any
import attr import attr
from signedjson.types import SigningKey from signedjson.types import SigningKey
@ -83,7 +83,7 @@ class EventBuilder:
room_version: RoomVersion room_version: RoomVersion
# MSC4291 makes the room ID == the create event ID. This means the create event has no room_id. # MSC4291 makes the room ID == the create event ID. This means the create event has no room_id.
room_id: Optional[str] room_id: str | None
type: str type: str
sender: str sender: str
@ -92,9 +92,9 @@ class EventBuilder:
# These only exist on a subset of events, so they raise AttributeError if # These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist. # someone tries to get them when they don't exist.
_state_key: Optional[str] = None _state_key: str | None = None
_redacts: Optional[str] = None _redacts: str | None = None
_origin_server_ts: Optional[int] = None _origin_server_ts: int | None = None
internal_metadata: EventInternalMetadata = attr.Factory( internal_metadata: EventInternalMetadata = attr.Factory(
lambda: EventInternalMetadata({}) lambda: EventInternalMetadata({})
@ -126,8 +126,8 @@ class EventBuilder:
async def build( async def build(
self, self,
prev_event_ids: list[str], prev_event_ids: list[str],
auth_event_ids: Optional[list[str]], auth_event_ids: list[str] | None,
depth: Optional[int] = None, depth: int | None = None,
) -> EventBase: ) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
@ -205,8 +205,8 @@ class EventBuilder:
format_version = self.room_version.event_format format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
prev_events: Union[StrCollection, list[tuple[str, dict[str, str]]]] prev_events: StrCollection | list[tuple[str, dict[str, str]]]
auth_events: Union[list[str], list[tuple[str, dict[str, str]]]] auth_events: list[str] | list[tuple[str, dict[str, str]]]
if format_version == EventFormatVersions.ROOM_V1_V2: if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids) auth_events = await self._store.add_event_hashes(auth_event_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids) prev_events = await self._store.add_event_hashes(prev_event_ids)
@ -327,7 +327,7 @@ def create_local_event_from_event_dict(
signing_key: SigningKey, signing_key: SigningKey,
room_version: RoomVersion, room_version: RoomVersion,
event_dict: JsonDict, event_dict: JsonDict,
internal_metadata_dict: Optional[JsonDict] = None, internal_metadata_dict: JsonDict | None = None,
) -> EventBase: ) -> EventBase:
"""Takes a fully formed event dict, ensuring that fields like """Takes a fully formed event dict, ensuring that fields like
`origin_server_ts` have correct values for a locally produced event, `origin_server_ts` have correct values for a locally produced event,

View File

@ -25,9 +25,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Iterable, Iterable,
Optional,
TypeVar, TypeVar,
Union,
) )
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -44,7 +42,7 @@ GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[dict[str, set[UserPresenceState]]] [Iterable[UserPresenceState]], Awaitable[dict[str, set[UserPresenceState]]]
] ]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS. # This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[set[str], str]]] GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[set[str] | str]]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,8 +75,8 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
# All methods that the module provides should be async, but this wasn't enforced # All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed # in the old module system, so we wrap them if needed
def async_wrapper( def async_wrapper(
f: Optional[Callable[P, R]], f: Callable[P, R] | None,
) -> Optional[Callable[P, Awaitable[R]]]: ) -> Callable[P, Awaitable[R]] | None:
# f might be None if the callback isn't implemented by the module. In this # f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None. # case we don't want to register a callback at all so we return None.
if f is None: if f is None:
@ -95,7 +93,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
return run return run
# Register the hooks through the module API. # Register the hooks through the module API.
hooks: dict[str, Optional[Callable[..., Any]]] = { hooks: dict[str, Callable[..., Any] | None] = {
hook: async_wrapper(getattr(presence_router, hook, None)) hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods for hook in presence_router_methods
} }
@ -118,8 +116,8 @@ class PresenceRouter:
def register_presence_router_callbacks( def register_presence_router_callbacks(
self, self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, get_users_for_states: GET_USERS_FOR_STATES_CALLBACK | None = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, get_interested_users: GET_INTERESTED_USERS_CALLBACK | None = None,
) -> None: ) -> None:
# PresenceRouter modules are required to implement both of these methods # PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner # or neither of them as they are assumed to act in a complementary manner
@ -191,7 +189,7 @@ class PresenceRouter:
return users_for_states return users_for_states
async def get_interested_users(self, user_id: str) -> Union[set[str], str]: async def get_interested_users(self, user_id: str) -> set[str] | str:
""" """
Retrieve a list of users that `user_id` is interested in receiving the Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with. presence of. This will be in addition to those they share a room with.

View File

@ -51,7 +51,7 @@ class UnpersistedEventContextBase(ABC):
def __init__(self, storage_controller: "StorageControllers"): def __init__(self, storage_controller: "StorageControllers"):
self._storage: "StorageControllers" = storage_controller self._storage: "StorageControllers" = storage_controller
self.app_service: Optional[ApplicationService] = None self.app_service: ApplicationService | None = None
@abstractmethod @abstractmethod
async def persist( async def persist(
@ -134,20 +134,20 @@ class EventContext(UnpersistedEventContextBase):
_storage: "StorageControllers" _storage: "StorageControllers"
state_group_deltas: dict[tuple[int, int], StateMap[str]] state_group_deltas: dict[tuple[int, int], StateMap[str]]
rejected: Optional[str] = None rejected: str | None = None
_state_group: Optional[int] = None _state_group: int | None = None
state_group_before_event: Optional[int] = None state_group_before_event: int | None = None
_state_delta_due_to_event: Optional[StateMap[str]] = None _state_delta_due_to_event: StateMap[str] | None = None
app_service: Optional[ApplicationService] = None app_service: ApplicationService | None = None
partial_state: bool = False partial_state: bool = False
@staticmethod @staticmethod
def with_state( def with_state(
storage: "StorageControllers", storage: "StorageControllers",
state_group: Optional[int], state_group: int | None,
state_group_before_event: Optional[int], state_group_before_event: int | None,
state_delta_due_to_event: Optional[StateMap[str]], state_delta_due_to_event: StateMap[str] | None,
partial_state: bool, partial_state: bool,
state_group_deltas: dict[tuple[int, int], StateMap[str]], state_group_deltas: dict[tuple[int, int], StateMap[str]],
) -> "EventContext": ) -> "EventContext":
@ -227,7 +227,7 @@ class EventContext(UnpersistedEventContextBase):
return context return context
@property @property
def state_group(self) -> Optional[int]: def state_group(self) -> int | None:
"""The ID of the state group for this event. """The ID of the state group for this event.
Note that state events are persisted with a state group which includes the new Note that state events are persisted with a state group which includes the new
@ -354,13 +354,13 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
""" """
_storage: "StorageControllers" _storage: "StorageControllers"
state_group_before_event: Optional[int] state_group_before_event: int | None
state_group_after_event: Optional[int] state_group_after_event: int | None
state_delta_due_to_event: Optional[StateMap[str]] state_delta_due_to_event: StateMap[str] | None
prev_group_for_state_group_before_event: Optional[int] prev_group_for_state_group_before_event: int | None
delta_ids_to_state_group_before_event: Optional[StateMap[str]] delta_ids_to_state_group_before_event: StateMap[str] | None
partial_state: bool partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None state_map_before_event: StateMap[str] | None = None
@classmethod @classmethod
async def batch_persist_unpersisted_contexts( async def batch_persist_unpersisted_contexts(
@ -511,7 +511,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
def _encode_state_group_delta( def _encode_state_group_delta(
state_group_delta: dict[tuple[int, int], StateMap[str]], state_group_delta: dict[tuple[int, int], StateMap[str]],
) -> list[tuple[int, int, Optional[list[tuple[str, str, str]]]]]: ) -> list[tuple[int, int, list[tuple[str, str, str]] | None]]:
if not state_group_delta: if not state_group_delta:
return [] return []
@ -538,8 +538,8 @@ def _decode_state_group_delta(
def _encode_state_dict( def _encode_state_dict(
state_dict: Optional[StateMap[str]], state_dict: StateMap[str] | None,
) -> Optional[list[tuple[str, str, str]]]: ) -> list[tuple[str, str, str]] | None:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in """Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can. JSON we need to convert them to a form that can.
""" """
@ -550,8 +550,8 @@ def _encode_state_dict(
def _decode_state_dict( def _decode_state_dict(
input: Optional[list[tuple[str, str, str]]], input: list[tuple[str, str, str]] | None,
) -> Optional[StateMap[str]]: ) -> StateMap[str] | None:
"""Decodes a state dict encoded using `_encode_state_dict` above""" """Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None: if input is None:
return None return None

View File

@ -30,8 +30,6 @@ from typing import (
Mapping, Mapping,
Match, Match,
MutableMapping, MutableMapping,
Optional,
Union,
) )
import attr import attr
@ -415,9 +413,9 @@ class SerializeEventConfig:
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
# The entity that requested the event. This is used to determine whether to include # The entity that requested the event. This is used to determine whether to include
# the transaction_id in the unsigned section of the event. # the transaction_id in the unsigned section of the event.
requester: Optional[Requester] = None requester: Requester | None = None
# List of event fields to include. If empty, all fields will be returned. # List of event fields to include. If empty, all fields will be returned.
only_event_fields: Optional[list[str]] = None only_event_fields: list[str] | None = None
# Some events can have stripped room state stored in the `unsigned` field. # Some events can have stripped room state stored in the `unsigned` field.
# This is required for invite and knock functionality. If this option is # This is required for invite and knock functionality. If this option is
# False, that state will be removed from the event before it is returned. # False, that state will be removed from the event before it is returned.
@ -439,7 +437,7 @@ def make_config_for_admin(existing: SerializeEventConfig) -> SerializeEventConfi
def serialize_event( def serialize_event(
e: Union[JsonDict, EventBase], e: JsonDict | EventBase,
time_now_ms: int, time_now_ms: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -480,7 +478,7 @@ def serialize_event(
# If we have a txn_id saved in the internal_metadata, we should include it in the # If we have a txn_id saved in the internal_metadata, we should include it in the
# unsigned section of the event if it was sent by the same session as the one # unsigned section of the event if it was sent by the same session as the one
# requesting the event. # requesting the event.
txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None) txn_id: str | None = getattr(e.internal_metadata, "txn_id", None)
if ( if (
txn_id is not None txn_id is not None
and config.requester is not None and config.requester is not None
@ -490,7 +488,7 @@ def serialize_event(
# this includes old events as well as those created by appservice, guests, # this includes old events as well as those created by appservice, guests,
# or with tokens minted with the admin API. For those events, fallback # or with tokens minted with the admin API. For those events, fallback
# to using the access token instead. # to using the access token instead.
event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None) event_device_id: str | None = getattr(e.internal_metadata, "device_id", None)
if event_device_id is not None: if event_device_id is not None:
if event_device_id == config.requester.device_id: if event_device_id == config.requester.device_id:
d["unsigned"]["transaction_id"] = txn_id d["unsigned"]["transaction_id"] = txn_id
@ -504,9 +502,7 @@ def serialize_event(
# #
# For guests and appservice users, we can't check the access token ID # For guests and appservice users, we can't check the access token ID
# so assume it is the same session. # so assume it is the same session.
event_token_id: Optional[int] = getattr( event_token_id: int | None = getattr(e.internal_metadata, "token_id", None)
e.internal_metadata, "token_id", None
)
if ( if (
( (
event_token_id is not None event_token_id is not None
@ -577,11 +573,11 @@ class EventClientSerializer:
async def serialize_event( async def serialize_event(
self, self,
event: Union[JsonDict, EventBase], event: JsonDict | EventBase,
time_now: int, time_now: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None, bundle_aggregations: dict[str, "BundledAggregations"] | None = None,
) -> JsonDict: ) -> JsonDict:
"""Serializes a single event. """Serializes a single event.
@ -712,11 +708,11 @@ class EventClientSerializer:
@trace @trace
async def serialize_events( async def serialize_events(
self, self,
events: Collection[Union[JsonDict, EventBase]], events: Collection[JsonDict | EventBase],
time_now: int, time_now: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None, bundle_aggregations: dict[str, "BundledAggregations"] | None = None,
) -> list[JsonDict]: ) -> list[JsonDict]:
"""Serializes multiple events. """Serializes multiple events.
@ -755,13 +751,13 @@ class EventClientSerializer:
self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback) self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
_PowerLevel = Union[str, int] _PowerLevel = str | int
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] PowerLevelsContent = Mapping[str, _PowerLevel | Mapping[str, _PowerLevel]]
def copy_and_fixup_power_levels_contents( def copy_and_fixup_power_levels_contents(
old_power_levels: PowerLevelsContent, old_power_levels: PowerLevelsContent,
) -> dict[str, Union[int, dict[str, int]]]: ) -> dict[str, int | dict[str, int]]:
"""Copy the content of a power_levels event, unfreezing immutabledicts along the way. """Copy the content of a power_levels event, unfreezing immutabledicts along the way.
We accept as input power level values which are strings, provided they represent an We accept as input power level values which are strings, provided they represent an
@ -777,7 +773,7 @@ def copy_and_fixup_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping): if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels: dict[str, Union[int, dict[str, int]]] = {} power_levels: dict[str, int | dict[str, int]] = {}
for k, v in old_power_levels.items(): for k, v in old_power_levels.items():
if isinstance(v, collections.abc.Mapping): if isinstance(v, collections.abc.Mapping):
@ -901,7 +897,7 @@ def strip_event(event: EventBase) -> JsonDict:
} }
def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]: def parse_stripped_state_event(raw_stripped_event: Any) -> StrippedStateEvent | None:
""" """
Given a raw value from an event's `unsigned` field, attempt to parse it into a Given a raw value from an event's `unsigned` field, attempt to parse it into a
`StrippedStateEvent`. `StrippedStateEvent`.
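The rewrites above are mechanical: `Optional[X]` becomes `X | None` and `Union[X, Y]` becomes `X | Y`. A minimal standalone sketch of the equivalence, assuming Python 3.10+ (the function and parameter names are illustrative only, not Synapse's):

    from typing import Optional, Union

    # PEP 604 (Python 3.10+) makes the | spelling interchangeable with
    # typing.Optional / typing.Union at runtime.
    assert (str | None) == Optional[str]
    assert (str | int) == Union[str, int]

    # Annotations mean the same thing in either spelling; only the syntax changes.
    def old_style(txn_id: Optional[str] = None) -> Union[str, int]:
        return txn_id or 0

    def new_style(txn_id: str | None = None) -> str | int:
        return txn_id or 0

    assert old_style() == new_style() == 0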

View File

@ -19,7 +19,7 @@
# #
# #
import collections.abc import collections.abc
from typing import Union, cast from typing import cast
import jsonschema import jsonschema
from pydantic import Field, StrictBool, StrictStr from pydantic import Field, StrictBool, StrictStr
@ -177,7 +177,7 @@ class EventValidator:
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None: def validate_builder(self, event: EventBase | EventBuilder) -> None:
"""Validates that the builder/event has roughly the right format. Only """Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the checks values that we expect a proto event to have, rather than all the
fields an event would have fields an event would have
@ -249,7 +249,7 @@ class EventValidator:
if not isinstance(d[s], str): if not isinstance(d[s], str):
raise SynapseError(400, "'%s' not a string type" % (s,)) raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None: def _ensure_state_event(self, event: EventBase | EventBuilder) -> None:
if not event.is_state(): if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,)) raise SynapseError(400, "'%s' must be state events" % (event.type,))
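`validate_builder` and `_ensure_state_event` above accept `EventBase | EventBuilder` and only touch fields common to both classes. A toy sketch of the same idea, using two hypothetical stand-in classes rather than Synapse's real types:

    class DraftEvent:
        """Hypothetical stand-in for an event builder."""

        def __init__(self, event_type: str, state_key: str | None = None) -> None:
            self.type = event_type
            self.state_key = state_key

    class StoredEvent:
        """Hypothetical stand-in for a persisted event."""

        def __init__(self, event_type: str, state_key: str | None = None) -> None:
            self.type = event_type
            self.state_key = state_key

    def ensure_state_event(event: DraftEvent | StoredEvent) -> None:
        # Both classes expose .state_key, so it doesn't matter which union
        # member we were given.
        if event.state_key is None:
            raise ValueError(f"'{event.type}' must be a state event")

    ensure_state_event(DraftEvent("m.room.name", state_key=""))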

View File

@ -20,7 +20,7 @@
# #
# #
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Sequence from typing import TYPE_CHECKING, Awaitable, Callable, Sequence
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -67,7 +67,7 @@ class FederationBase:
# We need to define this lazily otherwise we get a cyclic dependency. # We need to define this lazily otherwise we get a cyclic dependency.
# self._policy_handler = hs.get_room_policy_handler() # self._policy_handler = hs.get_room_policy_handler()
self._policy_handler: Optional[RoomPolicyHandler] = None self._policy_handler: RoomPolicyHandler | None = None
def _lazily_get_policy_handler(self) -> RoomPolicyHandler: def _lazily_get_policy_handler(self) -> RoomPolicyHandler:
"""Lazily get the room policy handler. """Lazily get the room policy handler.
@ -88,9 +88,8 @@ class FederationBase:
self, self,
room_version: RoomVersion, room_version: RoomVersion,
pdu: EventBase, pdu: EventBase,
record_failure_callback: Optional[ record_failure_callback: Callable[[EventBase, str], Awaitable[None]]
Callable[[EventBase, str], Awaitable[None]] | None = None,
] = None,
) -> EventBase: ) -> EventBase:
"""Checks that event is correctly signed by the sending server. """Checks that event is correctly signed by the sending server.

View File

@ -37,7 +37,6 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
TypeVar, TypeVar,
Union,
) )
import attr import attr
@ -263,7 +262,7 @@ class FederationClient(FederationBase):
user: UserID, user: UserID,
destination: str, destination: str,
query: dict[str, dict[str, dict[str, int]]], query: dict[str, dict[str, dict[str, int]]],
timeout: Optional[int], timeout: int | None,
) -> JsonDict: ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
@ -334,7 +333,7 @@ class FederationClient(FederationBase):
@tag_args @tag_args
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[list[EventBase]]: ) -> list[EventBase] | None:
"""Requests some more historic PDUs for the given room from the """Requests some more historic PDUs for the given room from the
given destination server. given destination server.
@ -381,8 +380,8 @@ class FederationClient(FederationBase):
destination: str, destination: str,
event_id: str, event_id: str,
room_version: RoomVersion, room_version: RoomVersion,
timeout: Optional[int] = None, timeout: int | None = None,
) -> Optional[EventBase]: ) -> EventBase | None:
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
server. Does not have any caching or rate limiting! server. Does not have any caching or rate limiting!
@ -441,7 +440,7 @@ class FederationClient(FederationBase):
@trace @trace
@tag_args @tag_args
async def get_pdu_policy_recommendation( async def get_pdu_policy_recommendation(
self, destination: str, pdu: EventBase, timeout: Optional[int] = None self, destination: str, pdu: EventBase, timeout: int | None = None
) -> str: ) -> str:
"""Requests that the destination server (typically a policy server) """Requests that the destination server (typically a policy server)
check the event and return its recommendation on how to handle the check the event and return its recommendation on how to handle the
@ -497,8 +496,8 @@ class FederationClient(FederationBase):
@trace @trace
@tag_args @tag_args
async def ask_policy_server_to_sign_event( async def ask_policy_server_to_sign_event(
self, destination: str, pdu: EventBase, timeout: Optional[int] = None self, destination: str, pdu: EventBase, timeout: int | None = None
) -> Optional[JsonDict]: ) -> JsonDict | None:
"""Requests that the destination server (typically a policy server) """Requests that the destination server (typically a policy server)
sign the event as not spam. sign the event as not spam.
@ -538,8 +537,8 @@ class FederationClient(FederationBase):
destinations: Collection[str], destinations: Collection[str],
event_id: str, event_id: str,
room_version: RoomVersion, room_version: RoomVersion,
timeout: Optional[int] = None, timeout: int | None = None,
) -> Optional[PulledPduInfo]: ) -> PulledPduInfo | None:
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
servers. servers.
@ -832,10 +831,9 @@ class FederationClient(FederationBase):
pdu: EventBase, pdu: EventBase,
origin: str, origin: str,
room_version: RoomVersion, room_version: RoomVersion,
record_failure_callback: Optional[ record_failure_callback: Callable[[EventBase, str], Awaitable[None]]
Callable[[EventBase, str], Awaitable[None]] | None = None,
] = None, ) -> EventBase | None:
) -> Optional[EventBase]:
"""Takes a PDU and checks its signatures and hashes. """Takes a PDU and checks its signatures and hashes.
If the PDU fails its signature check then we check if we have it in the If the PDU fails its signature check then we check if we have it in the
@ -931,7 +929,7 @@ class FederationClient(FederationBase):
description: str, description: str,
destinations: Iterable[str], destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]], callback: Callable[[str], Awaitable[T]],
failover_errcodes: Optional[Container[str]] = None, failover_errcodes: Container[str] | None = None,
failover_on_unknown_endpoint: bool = False, failover_on_unknown_endpoint: bool = False,
) -> T: ) -> T:
"""Try an operation on a series of servers, until it succeeds """Try an operation on a series of servers, until it succeeds
@ -1046,7 +1044,7 @@ class FederationClient(FederationBase):
user_id: str, user_id: str,
membership: str, membership: str,
content: dict, content: dict,
params: Optional[Mapping[str, Union[str, Iterable[str]]]], params: Mapping[str, str | Iterable[str]] | None,
) -> tuple[str, EventBase, RoomVersion]: ) -> tuple[str, EventBase, RoomVersion]:
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -1563,11 +1561,11 @@ class FederationClient(FederationBase):
async def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: int | None = None,
since_token: Optional[str] = None, since_token: str | None = None,
search_filter: Optional[dict] = None, search_filter: dict | None = None,
include_all_networks: bool = False, include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None, third_party_instance_id: str | None = None,
) -> JsonDict: ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver """Get the list of public rooms from a remote homeserver
@ -1676,7 +1674,7 @@ class FederationClient(FederationBase):
async def get_room_complexity( async def get_room_complexity(
self, destination: str, room_id: str self, destination: str, room_id: str
) -> Optional[JsonDict]: ) -> JsonDict | None:
""" """
Fetch the complexity of a remote room from another server. Fetch the complexity of a remote room from another server.
@ -1987,10 +1985,10 @@ class FederationClient(FederationBase):
max_timeout_ms: int, max_timeout_ms: int,
download_ratelimiter: Ratelimiter, download_ratelimiter: Ratelimiter,
ip_address: str, ip_address: str,
) -> Union[ ) -> (
tuple[int, dict[bytes, list[bytes]], bytes], tuple[int, dict[bytes, list[bytes]], bytes]
tuple[int, dict[bytes, list[bytes]]], | tuple[int, dict[bytes, list[bytes]]]
]: ):
try: try:
return await self.transport_layer.federation_download_media( return await self.transport_layer.federation_download_media(
destination, destination,
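One layout detail visible in the media-download hunk above: a `Union[...]` that spanned several lines is rewritten as a `|` chain wrapped in parentheses, since there are no longer square brackets to continue across lines. A small sketch of that layout with a hypothetical function and simplified types:

    # Either (status, headers, body) or just (status, headers), mirroring the
    # parenthesised multi-line union shown above.
    def download_media_stub(with_body: bool) -> (
        tuple[int, dict[bytes, list[bytes]], bytes]
        | tuple[int, dict[bytes, list[bytes]]]
    ):
        headers: dict[bytes, list[bytes]] = {b"content-type": [b"text/plain"]}
        if with_body:
            return 200, headers, b"ok"
        return 200, headers

    assert len(download_media_stub(True)) == 3
    assert len(download_media_stub(False)) == 2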

View File

@ -28,8 +28,6 @@ from typing import (
Callable, Callable,
Collection, Collection,
Mapping, Mapping,
Optional,
Union,
) )
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
@ -176,13 +174,11 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache: ResponseCache[tuple[str, Optional[str]]] = ( self._state_resp_cache: ResponseCache[tuple[str, str | None]] = ResponseCache(
ResponseCache( clock=hs.get_clock(),
clock=hs.get_clock(), name="state_resp",
name="state_resp", server_name=self.server_name,
server_name=self.server_name, timeout_ms=30000,
timeout_ms=30000,
)
) )
self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache( self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(), clock=hs.get_clock(),
@ -666,7 +662,7 @@ class FederationServer(FederationBase):
async def on_pdu_request( async def on_pdu_request(
self, origin: str, event_id: str self, origin: str, event_id: str
) -> tuple[int, Union[JsonDict, str]]: ) -> tuple[int, JsonDict | str]:
pdu = await self.handler.get_persisted_pdu(origin, event_id) pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu: if pdu:
@ -763,7 +759,7 @@ class FederationServer(FederationBase):
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
state_event_ids: Collection[str] state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]] servers_in_room: Collection[str] | None
if caller_supports_partial_state: if caller_supports_partial_state:
summary = await self.store.get_room_summary(room_id) summary = await self.store.get_room_summary(room_id)
state_event_ids = _get_event_ids_for_partial_state_join( state_event_ids = _get_event_ids_for_partial_state_join(
@ -1126,7 +1122,7 @@ class FederationServer(FederationBase):
return {"events": serialize_and_filter_pdus(missing_events, time_now)} return {"events": serialize_and_filter_pdus(missing_events, time_now)}
async def on_openid_userinfo(self, token: str) -> Optional[str]: async def on_openid_userinfo(self, token: str) -> str | None:
ts_now_ms = self._clock.time_msec() ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
@ -1205,7 +1201,7 @@ class FederationServer(FederationBase):
async def _get_next_nonspam_staged_event_for_room( async def _get_next_nonspam_staged_event_for_room(
self, room_id: str, room_version: RoomVersion self, room_id: str, room_version: RoomVersion
) -> Optional[tuple[str, EventBase]]: ) -> tuple[str, EventBase] | None:
"""Fetch the first non-spam event from staging queue. """Fetch the first non-spam event from staging queue.
Args: Args:
@ -1246,8 +1242,8 @@ class FederationServer(FederationBase):
room_id: str, room_id: str,
room_version: RoomVersion, room_version: RoomVersion,
lock: Lock, lock: Lock,
latest_origin: Optional[str] = None, latest_origin: str | None = None,
latest_event: Optional[EventBase] = None, latest_event: EventBase | None = None,
) -> None: ) -> None:
"""Process events in the staging area for the given room. """Process events in the staging area for the given room.

View File

@ -27,7 +27,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
""" """
import logging import logging
from typing import Optional
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@ -44,7 +43,7 @@ class TransactionActions:
async def have_responded( async def have_responded(
self, origin: str, transaction: Transaction self, origin: str, transaction: Transaction
) -> Optional[tuple[int, JsonDict]]: ) -> tuple[int, JsonDict] | None:
"""Have we already responded to a transaction with the same id and """Have we already responded to a transaction with the same id and
origin? origin?

View File

@ -42,7 +42,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Hashable, Hashable,
Iterable, Iterable,
Optional,
Sized, Sized,
) )
@ -217,7 +216,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
destination: str, destination: str,
edu_type: str, edu_type: str,
content: JsonDict, content: JsonDict,
key: Optional[Hashable] = None, key: Hashable | None = None,
) -> None: ) -> None:
"""As per FederationSender""" """As per FederationSender"""
if self.is_mine_server_name(destination): if self.is_mine_server_name(destination):

View File

@ -138,7 +138,6 @@ from typing import (
Hashable, Hashable,
Iterable, Iterable,
Literal, Literal,
Optional,
) )
import attr import attr
@ -266,7 +265,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
destination: str, destination: str,
edu_type: str, edu_type: str,
content: JsonDict, content: JsonDict,
key: Optional[Hashable] = None, key: Hashable | None = None,
) -> None: ) -> None:
"""Construct an Edu object, and queue it for sending """Construct an Edu object, and queue it for sending
@ -410,7 +409,7 @@ class FederationSender(AbstractFederationSender):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name self.is_mine_server_name = hs.is_mine_server_name
self._presence_router: Optional["PresenceRouter"] = None self._presence_router: "PresenceRouter" | None = None
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -481,7 +480,7 @@ class FederationSender(AbstractFederationSender):
def _get_per_destination_queue( def _get_per_destination_queue(
self, destination: str self, destination: str
) -> Optional[PerDestinationQueue]: ) -> PerDestinationQueue | None:
"""Get or create a PerDestinationQueue for the given destination """Get or create a PerDestinationQueue for the given destination
Args: Args:
@ -605,7 +604,7 @@ class FederationSender(AbstractFederationSender):
) )
return return
destinations: Optional[Collection[str]] = None destinations: Collection[str] | None = None
if not event.prev_event_ids(): if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty # If there are no prev event IDs then the state is empty
# and so no remote servers in the room # and so no remote servers in the room
@ -1010,7 +1009,7 @@ class FederationSender(AbstractFederationSender):
destination: str, destination: str,
edu_type: str, edu_type: str,
content: JsonDict, content: JsonDict,
key: Optional[Hashable] = None, key: Hashable | None = None,
) -> None: ) -> None:
"""Construct an Edu object, and queue it for sending """Construct an Edu object, and queue it for sending
@ -1038,7 +1037,7 @@ class FederationSender(AbstractFederationSender):
self.send_edu(edu, key) self.send_edu(edu, key)
def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None: def send_edu(self, edu: Edu, key: Hashable | None) -> None:
"""Queue an EDU for sending """Queue an EDU for sending
Args: Args:
@ -1134,7 +1133,7 @@ class FederationSender(AbstractFederationSender):
In order to reduce load spikes, adds a delay between each destination. In order to reduce load spikes, adds a delay between each destination.
""" """
last_processed: Optional[str] = None last_processed: str | None = None
while not self._is_shutdown: while not self._is_shutdown:
destinations_to_wake = ( destinations_to_wake = (

View File

@ -23,7 +23,7 @@ import datetime
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Hashable, Iterable, Optional from typing import TYPE_CHECKING, Hashable, Iterable
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -121,7 +121,7 @@ class PerDestinationQueue:
self._destination = destination self._destination = destination
self.transmission_loop_running = False self.transmission_loop_running = False
self._transmission_loop_enabled = True self._transmission_loop_enabled = True
self.active_transmission_loop: Optional[defer.Deferred] = None self.active_transmission_loop: defer.Deferred | None = None
# Flag to signal to any running transmission loop that there is new data # Flag to signal to any running transmission loop that there is new data
# queued up to be sent. # queued up to be sent.
@ -142,7 +142,7 @@ class PerDestinationQueue:
# Cache of the last successfully-transmitted stream ordering for this # Cache of the last successfully-transmitted stream ordering for this
# destination (we are the only updater so this is safe) # destination (we are the only updater so this is safe)
self._last_successful_stream_ordering: Optional[int] = None self._last_successful_stream_ordering: int | None = None
# a queue of pending PDUs # a queue of pending PDUs
self._pending_pdus: list[EventBase] = [] self._pending_pdus: list[EventBase] = []
@ -742,9 +742,9 @@ class _TransactionQueueManager:
queue: PerDestinationQueue queue: PerDestinationQueue
_device_stream_id: Optional[int] = None _device_stream_id: int | None = None
_device_list_id: Optional[int] = None _device_list_id: int | None = None
_last_stream_ordering: Optional[int] = None _last_stream_ordering: int | None = None
_pdus: list[EventBase] = attr.Factory(list) _pdus: list[EventBase] = attr.Factory(list)
async def __aenter__(self) -> tuple[list[EventBase], list[Edu]]: async def __aenter__(self) -> tuple[list[EventBase], list[Edu]]:
@ -845,9 +845,9 @@ class _TransactionQueueManager:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[type[BaseException]], exc_type: type[BaseException] | None,
exc: Optional[BaseException], exc: BaseException | None,
tb: Optional[TracebackType], tb: TracebackType | None,
) -> None: ) -> None:
if exc_type is not None: if exc_type is not None:
# Failed to send transaction, so we bail out. # Failed to send transaction, so we bail out.
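The `__aexit__` signature above is the standard context-manager shape, just respelled with `| None` for each argument. A self-contained sketch of the same shape on a hypothetical async context manager:

    import asyncio
    from types import TracebackType

    class NoopTransaction:
        """Hypothetical async context manager mirroring the __aexit__ shape above."""

        async def __aenter__(self) -> "NoopTransaction":
            return self

        async def __aexit__(
            self,
            exc_type: type[BaseException] | None,
            exc: BaseException | None,
            tb: TracebackType | None,
        ) -> None:
            # Returning None (falsy) lets any exception propagate.
            if exc_type is not None:
                print(f"bailing out after {exc_type.__name__}")

    async def main() -> None:
        async with NoopTransaction():
            pass

    asyncio.run(main())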

View File

@ -31,8 +31,6 @@ from typing import (
Generator, Generator,
Iterable, Iterable,
Mapping, Mapping,
Optional,
Union,
) )
import attr import attr
@ -122,7 +120,7 @@ class TransportLayerClient:
) )
async def get_event( async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None self, destination: str, event_id: str, timeout: int | None = None
) -> JsonDict: ) -> JsonDict:
"""Requests the pdu with give id and origin from the given server. """Requests the pdu with give id and origin from the given server.
@ -144,7 +142,7 @@ class TransportLayerClient:
) )
async def get_policy_recommendation_for_pdu( async def get_policy_recommendation_for_pdu(
self, destination: str, event: EventBase, timeout: Optional[int] = None self, destination: str, event: EventBase, timeout: int | None = None
) -> JsonDict: ) -> JsonDict:
"""Requests the policy recommendation for the given pdu from the given policy server. """Requests the policy recommendation for the given pdu from the given policy server.
@ -171,7 +169,7 @@ class TransportLayerClient:
) )
async def ask_policy_server_to_sign_event( async def ask_policy_server_to_sign_event(
self, destination: str, event: EventBase, timeout: Optional[int] = None self, destination: str, event: EventBase, timeout: int | None = None
) -> JsonDict: ) -> JsonDict:
"""Requests that the destination server (typically a policy server) """Requests that the destination server (typically a policy server)
sign the event as not spam. sign the event as not spam.
@ -198,7 +196,7 @@ class TransportLayerClient:
async def backfill( async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[Union[JsonDict, list]]: ) -> JsonDict | list | None:
"""Requests `limit` previous PDUs in a given context before list of """Requests `limit` previous PDUs in a given context before list of
PDUs. PDUs.
@ -235,7 +233,7 @@ class TransportLayerClient:
async def timestamp_to_event( async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: Direction self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, list]: ) -> JsonDict | list:
""" """
Calls a remote federating server at `destination` asking for their Calls a remote federating server at `destination` asking for their
closest event to the given timestamp in the given direction. closest event to the given timestamp in the given direction.
@ -270,7 +268,7 @@ class TransportLayerClient:
async def send_transaction( async def send_transaction(
self, self,
transaction: Transaction, transaction: Transaction,
json_data_callback: Optional[Callable[[], JsonDict]] = None, json_data_callback: Callable[[], JsonDict] | None = None,
) -> JsonDict: ) -> JsonDict:
"""Sends the given Transaction to its destination """Sends the given Transaction to its destination
@ -343,7 +341,7 @@ class TransportLayerClient:
room_id: str, room_id: str,
user_id: str, user_id: str,
membership: str, membership: str,
params: Optional[Mapping[str, Union[str, Iterable[str]]]], params: Mapping[str, str | Iterable[str]] | None,
) -> JsonDict: ) -> JsonDict:
"""Asks a remote server to build and sign us a membership event """Asks a remote server to build and sign us a membership event
@ -528,11 +526,11 @@ class TransportLayerClient:
async def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: int | None = None,
since_token: Optional[str] = None, since_token: str | None = None,
search_filter: Optional[dict] = None, search_filter: dict | None = None,
include_all_networks: bool = False, include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None, third_party_instance_id: str | None = None,
) -> JsonDict: ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver """Get the list of public rooms from a remote homeserver
@ -567,7 +565,7 @@ class TransportLayerClient:
) )
raise raise
else: else:
args: dict[str, Union[str, Iterable[str]]] = { args: dict[str, str | Iterable[str]] = {
"include_all_networks": "true" if include_all_networks else "false" "include_all_networks": "true" if include_all_networks else "false"
} }
if third_party_instance_id: if third_party_instance_id:
@ -694,7 +692,7 @@ class TransportLayerClient:
user: UserID, user: UserID,
destination: str, destination: str,
query_content: JsonDict, query_content: JsonDict,
timeout: Optional[int], timeout: int | None,
) -> JsonDict: ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
@ -740,7 +738,7 @@ class TransportLayerClient:
user: UserID, user: UserID,
destination: str, destination: str,
query_content: JsonDict, query_content: JsonDict,
timeout: Optional[int], timeout: int | None,
) -> JsonDict: ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
@ -997,13 +995,13 @@ class SendJoinResponse:
event_dict: JsonDict event_dict: JsonDict
# The parsed join event from the /send_join response. This will be None if # The parsed join event from the /send_join response. This will be None if
# "event" is not included in the response. # "event" is not included in the response.
event: Optional[EventBase] = None event: EventBase | None = None
# The room state is incomplete # The room state is incomplete
members_omitted: bool = False members_omitted: bool = False
# List of servers in the room # List of servers in the room
servers_in_room: Optional[list[str]] = None servers_in_room: list[str] | None = None
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
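The old `backfill` return type above nested `Optional` around a `Union`; in the new spelling that flattens into a single chain, `JsonDict | list | None`. A sketch of handling such a value, with a simplified `JsonDict` alias standing in for Synapse's (and using the fact that `isinstance` accepts `X | Y` unions on Python 3.10+):

    from typing import Any

    JsonDict = dict[str, Any]  # simplified stand-in alias

    def parse_backfill_stub(raw: object) -> JsonDict | list | None:
        # Accept either a mapping or a list payload; anything else becomes None.
        if isinstance(raw, dict | list):
            return raw
        return None

    assert parse_backfill_stub({"pdus": []}) == {"pdus": []}
    assert parse_backfill_stub("not json") is None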

View File

@ -20,7 +20,7 @@
# #
# #
import logging import logging
from typing import TYPE_CHECKING, Iterable, Literal, Optional from typing import TYPE_CHECKING, Iterable, Literal
from synapse.api.errors import FederationDeniedError, SynapseError from synapse.api.errors import FederationDeniedError, SynapseError
from synapse.federation.transport.server._base import ( from synapse.federation.transport.server._base import (
@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource): class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
def __init__(self, hs: "HomeServer", servlet_groups: Optional[list[str]] = None): def __init__(self, hs: "HomeServer", servlet_groups: list[str] | None = None):
"""Initialize the TransportLayerServer """Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in Will by default register all servlets. For custom behaviour, pass in
@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access: if not self.allow_access:
raise FederationDeniedError(origin) raise FederationDeniedError(origin)
limit: Optional[int] = parse_integer_from_args(query, "limit", 0) limit: int | None = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None) since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args( include_all_networks = parse_boolean_from_args(
query, "include_all_networks", default=False query, "include_all_networks", default=False
@ -170,7 +170,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access: if not self.allow_access:
raise FederationDeniedError(origin) raise FederationDeniedError(origin)
limit: Optional[int] = int(content.get("limit", 100)) limit: int | None = int(content.get("limit", 100))
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
@ -240,7 +240,7 @@ class OpenIdUserInfo(BaseFederationServlet):
async def on_GET( async def on_GET(
self, self,
origin: Optional[str], origin: str | None,
content: Literal[None], content: Literal[None],
query: dict[bytes, list[bytes]], query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]: ) -> tuple[int, JsonDict]:
@ -281,7 +281,7 @@ def register_servlets(
resource: HttpServer, resource: HttpServer,
authenticator: Authenticator, authenticator: Authenticator,
ratelimiter: FederationRateLimiter, ratelimiter: FederationRateLimiter,
servlet_groups: Optional[Iterable[str]] = None, servlet_groups: Iterable[str] | None = None,
) -> None: ) -> None:
"""Initialize and register servlet classes. """Initialize and register servlet classes.

View File

@ -24,7 +24,7 @@ import logging
import re import re
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast from typing import TYPE_CHECKING, Any, Awaitable, Callable, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.api.urls import FEDERATION_V1_PREFIX
@ -77,7 +77,7 @@ class Authenticator:
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request( async def authenticate_request(
self, request: SynapseRequest, content: Optional[JsonDict] self, request: SynapseRequest, content: JsonDict | None
) -> str: ) -> str:
now = self._clock.time_msec() now = self._clock.time_msec()
json_request: JsonDict = { json_request: JsonDict = {
@ -165,7 +165,7 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin) logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, Optional[str]]: def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, str | None]:
"""Parse an X-Matrix auth header """Parse an X-Matrix auth header
Args: Args:
@ -252,7 +252,7 @@ class BaseFederationServlet:
components as specified in the path match regexp. components as specified in the path match regexp.
Returns: Returns:
Optional[tuple[int, object]]: either (response code, response object) to tuple[int, object] | None: either (response code, response object) to
return a JSON response, or None if the request has already been handled. return a JSON response, or None if the request has already been handled.
Raises: Raises:
@ -289,7 +289,7 @@ class BaseFederationServlet:
@functools.wraps(func) @functools.wraps(func)
async def new_func( async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[tuple[int, Any]]: ) -> tuple[int, Any] | None:
"""A callback which can be passed to HttpServer.RegisterPaths """A callback which can be passed to HttpServer.RegisterPaths
Args: Args:
@ -309,7 +309,7 @@ class BaseFederationServlet:
try: try:
with start_active_span("authenticate_request"): with start_active_span("authenticate_request"):
origin: Optional[str] = await authenticator.authenticate_request( origin: str | None = await authenticator.authenticate_request(
request, content request, content
) )
except NoAuthenticationError: except NoAuthenticationError:

View File

@ -24,9 +24,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Literal, Literal,
Mapping, Mapping,
Optional,
Sequence, Sequence,
Union,
) )
from synapse.api.constants import Direction, EduTypes from synapse.api.constants import Direction, EduTypes
@ -156,7 +154,7 @@ class FederationEventServlet(BaseFederationServerServlet):
content: Literal[None], content: Literal[None],
query: dict[bytes, list[bytes]], query: dict[bytes, list[bytes]],
event_id: str, event_id: str,
) -> tuple[int, Union[JsonDict, str]]: ) -> tuple[int, JsonDict | str]:
return await self.handler.on_pdu_request(origin, event_id) return await self.handler.on_pdu_request(origin, event_id)
@ -642,7 +640,7 @@ class On3pidBindServlet(BaseFederationServerServlet):
REQUIRE_AUTH = False REQUIRE_AUTH = False
async def on_POST( async def on_POST(
self, origin: Optional[str], content: JsonDict, query: dict[bytes, list[bytes]] self, origin: str | None, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]: ) -> tuple[int, JsonDict]:
if "invites" in content: if "invites" in content:
last_exception = None last_exception = None
@ -676,7 +674,7 @@ class FederationVersionServlet(BaseFederationServlet):
async def on_GET( async def on_GET(
self, self,
origin: Optional[str], origin: str | None,
content: Literal[None], content: Literal[None],
query: dict[bytes, list[bytes]], query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]: ) -> tuple[int, JsonDict]:
@ -812,7 +810,7 @@ class FederationMediaDownloadServlet(BaseFederationServerServlet):
async def on_GET( async def on_GET(
self, self,
origin: Optional[str], origin: str | None,
content: Literal[None], content: Literal[None],
request: SynapseRequest, request: SynapseRequest,
media_id: str, media_id: str,
@ -852,7 +850,7 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
async def on_GET( async def on_GET(
self, self,
origin: Optional[str], origin: str | None,
content: Literal[None], content: Literal[None],
request: SynapseRequest, request: SynapseRequest,
media_id: str, media_id: str,

View File

@ -24,7 +24,7 @@ server protocol.
""" """
import logging import logging
from typing import Optional, Sequence from typing import Sequence
import attr import attr
@ -70,7 +70,7 @@ class Edu:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
def _none_to_list(edus: Optional[list[JsonDict]]) -> list[JsonDict]: def _none_to_list(edus: list[JsonDict] | None) -> list[JsonDict]:
if edus is None: if edus is None:
return [] return []
return edus return edus
@ -128,6 +128,6 @@ def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> list[JsonDict]:
def serialize_and_filter_pdus( def serialize_and_filter_pdus(
pdus: Sequence[EventBase], time_now: Optional[int] = None pdus: Sequence[EventBase], time_now: int | None = None
) -> list[JsonDict]: ) -> list[JsonDict]:
return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus]) return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus])
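Default-`None` parameters keep their `= None` default; only the annotation changes, as in `_none_to_list` and `serialize_and_filter_pdus` above. A tiny sketch of the same normalisation pattern (hypothetical helper name):

    from typing import Any

    JsonDict = dict[str, Any]  # simplified stand-in alias

    def none_to_list(edus: list[JsonDict] | None = None) -> list[JsonDict]:
        # Treat a missing EDU list as empty, mirroring the helper above.
        return [] if edus is None else edus

    assert none_to_list() == []
    assert none_to_list([{"edu_type": "m.typing"}]) == [{"edu_type": "m.typing"}]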

View File

@ -21,7 +21,7 @@
# #
import logging import logging
import random import random
from typing import TYPE_CHECKING, Awaitable, Callable, Optional from typing import TYPE_CHECKING, Awaitable, Callable
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import ( from synapse.replication.http.account_data import (
@ -40,9 +40,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[ ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[[str, str | None, str, JsonDict], Awaitable]
[str, Optional[str], str, JsonDict], Awaitable
]
class AccountDataHandler: class AccountDataHandler:
@ -72,7 +70,7 @@ class AccountDataHandler:
] = [] ] = []
def register_module_callbacks( def register_module_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None self, on_account_data_updated: ON_ACCOUNT_DATA_UPDATED_CALLBACK | None = None
) -> None: ) -> None:
"""Register callbacks from modules.""" """Register callbacks from modules."""
if on_account_data_updated is not None: if on_account_data_updated is not None:
@ -81,7 +79,7 @@ class AccountDataHandler:
async def _notify_modules( async def _notify_modules(
self, self,
user_id: str, user_id: str,
room_id: Optional[str], room_id: str | None,
account_data_type: str, account_data_type: str,
content: JsonDict, content: JsonDict,
) -> None: ) -> None:
@ -143,7 +141,7 @@ class AccountDataHandler:
async def remove_account_data_for_room( async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str self, user_id: str, room_id: str, account_data_type: str
) -> Optional[int]: ) -> int | None:
""" """
Deletes the room account data for the given user and account data type. Deletes the room account data for the given user and account data type.
@ -219,7 +217,7 @@ class AccountDataHandler:
async def remove_account_data_for_user( async def remove_account_data_for_user(
self, user_id: str, account_data_type: str self, user_id: str, account_data_type: str
) -> Optional[int]: ) -> int | None:
"""Removes a piece of global account_data for a user. """Removes a piece of global account_data for a user.
Args: Args:
@ -324,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
limit: int, limit: int,
room_ids: StrCollection, room_ids: StrCollection,
is_guest: bool, is_guest: bool,
explicit_room_id: Optional[str] = None, explicit_room_id: str | None = None,
) -> tuple[list[JsonDict], int]: ) -> tuple[list[JsonDict], int]:
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key
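Callback type aliases such as `ON_ACCOUNT_DATA_UPDATED_CALLBACK` above keep the same `Callable[...]` shape; the `Optional[str]` argument simply becomes `str | None`. A minimal sketch of declaring, registering and invoking such a callback (the names and registry are illustrative, not Synapse's module API):

    import asyncio
    from typing import Any, Awaitable, Callable

    JsonDict = dict[str, Any]  # simplified stand-in alias

    # (user_id, room_id or None for global account data, data type, content)
    OnAccountDataUpdated = Callable[[str, str | None, str, JsonDict], Awaitable[None]]

    _callbacks: list[OnAccountDataUpdated] = []

    def register(callback: OnAccountDataUpdated) -> None:
        _callbacks.append(callback)

    async def log_update(
        user_id: str, room_id: str | None, data_type: str, content: JsonDict
    ) -> None:
        scope = room_id if room_id is not None else "<global>"
        print(f"{user_id} updated {data_type} in {scope}: {content}")

    register(log_update)
    asyncio.run(_callbacks[0]("@alice:example.org", None, "m.push_rules", {}))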

View File

@ -21,7 +21,7 @@
import email.mime.multipart import email.mime.multipart
import email.utils import email.utils
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -108,8 +108,8 @@ class AccountValidityHandler:
async def on_user_login( async def on_user_login(
self, self,
user_id: str, user_id: str,
auth_provider_type: Optional[str], auth_provider_type: str | None,
auth_provider_id: Optional[str], auth_provider_id: str | None,
) -> None: ) -> None:
"""Tell third-party modules about a user logins. """Tell third-party modules about a user logins.
@ -326,9 +326,9 @@ class AccountValidityHandler:
async def renew_account_for_user( async def renew_account_for_user(
self, self,
user_id: str, user_id: str,
expiration_ts: Optional[int] = None, expiration_ts: int | None = None,
email_sent: bool = False, email_sent: bool = False,
renewal_token: Optional[str] = None, renewal_token: str | None = None,
) -> int: ) -> int:
"""Renews the account attached to a given user by pushing back the """Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's expiration date by the current validity period in the server's

View File

@ -25,7 +25,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Mapping, Mapping,
Optional,
Sequence, Sequence,
) )
@ -71,7 +70,7 @@ class AdminHandler:
self.hs = hs self.hs = hs
async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]: async def get_redact_task(self, redact_id: str) -> ScheduledTask | None:
"""Get the current status of an active redaction process """Get the current status of an active redaction process
Args: Args:
@ -99,11 +98,9 @@ class AdminHandler:
return ret return ret
async def get_user(self, user: UserID) -> Optional[JsonMapping]: async def get_user(self, user: UserID) -> JsonMapping | None:
"""Function to get user details""" """Function to get user details"""
user_info: Optional[UserInfo] = await self._store.get_user_by_id( user_info: UserInfo | None = await self._store.get_user_by_id(user.to_string())
user.to_string()
)
if user_info is None: if user_info is None:
return None return None
@ -355,8 +352,8 @@ class AdminHandler:
rooms: list, rooms: list,
requester: JsonMapping, requester: JsonMapping,
use_admin: bool, use_admin: bool,
reason: Optional[str], reason: str | None,
limit: Optional[int], limit: int | None,
) -> str: ) -> str:
""" """
Start a task redacting the events of the given user in the given rooms Start a task redacting the events of the given user in the given rooms
@ -408,7 +405,7 @@ class AdminHandler:
async def _redact_all_events( async def _redact_all_events(
self, task: ScheduledTask self, task: ScheduledTask
) -> tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]: ) -> tuple[TaskStatus, Mapping[str, Any] | None, str | None]:
""" """
Task to redact all of a users events in the given rooms, tracking which, if any, events Task to redact all of a users events in the given rooms, tracking which, if any, events
whose redaction failed whose redaction failed

View File

@ -24,8 +24,6 @@ from typing import (
Collection, Collection,
Iterable, Iterable,
Mapping, Mapping,
Optional,
Union,
) )
from prometheus_client import Counter from prometheus_client import Counter
@ -240,8 +238,8 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral( def notify_interested_services_ephemeral(
self, self,
stream_key: StreamKeyType, stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], new_token: int | RoomStreamToken | MultiWriterStreamToken,
users: Collection[Union[str, UserID]], users: Collection[str | UserID],
) -> None: ) -> None:
""" """
This is called by the notifier in the background when an ephemeral event is handled This is called by the notifier in the background when an ephemeral event is handled
@ -340,8 +338,8 @@ class ApplicationServicesHandler:
self, self,
services: list[ApplicationService], services: list[ApplicationService],
stream_key: StreamKeyType, stream_key: StreamKeyType,
new_token: Union[int, MultiWriterStreamToken], new_token: int | MultiWriterStreamToken,
users: Collection[Union[str, UserID]], users: Collection[str | UserID],
) -> None: ) -> None:
logger.debug("Checking interested services for %s", stream_key) logger.debug("Checking interested services for %s", stream_key)
with Measure( with Measure(
@ -498,8 +496,8 @@ class ApplicationServicesHandler:
async def _handle_presence( async def _handle_presence(
self, self,
service: ApplicationService, service: ApplicationService,
users: Collection[Union[str, UserID]], users: Collection[str | UserID],
new_token: Optional[int], new_token: int | None,
) -> list[JsonMapping]: ) -> list[JsonMapping]:
""" """
Return the latest presence updates that the given application service should receive. Return the latest presence updates that the given application service should receive.
@ -559,7 +557,7 @@ class ApplicationServicesHandler:
self, self,
service: ApplicationService, service: ApplicationService,
new_token: int, new_token: int,
users: Collection[Union[str, UserID]], users: Collection[str | UserID],
) -> list[JsonDict]: ) -> list[JsonDict]:
""" """
Given an application service, determine which events it should receive Given an application service, determine which events it should receive
@ -733,7 +731,7 @@ class ApplicationServicesHandler:
async def query_room_alias_exists( async def query_room_alias_exists(
self, room_alias: RoomAlias self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]: ) -> RoomAliasMapping | None:
"""Check if an application service knows this room alias exists. """Check if an application service knows this room alias exists.
Args: Args:
@ -782,7 +780,7 @@ class ApplicationServicesHandler:
return ret return ret
async def get_3pe_protocols( async def get_3pe_protocols(
self, only_protocol: Optional[str] = None self, only_protocol: str | None = None
) -> dict[str, JsonDict]: ) -> dict[str, JsonDict]:
services = self.store.get_app_services() services = self.store.get_app_services()
protocols: dict[str, list[JsonDict]] = {} protocols: dict[str, list[JsonDict]] = {}
@ -935,7 +933,7 @@ class ApplicationServicesHandler:
return claimed_keys, missing return claimed_keys, missing
async def query_keys( async def query_keys(
self, query: Mapping[str, Optional[list[str]]] self, query: Mapping[str, list[str] | None]
) -> dict[str, dict[str, dict[str, JsonDict]]]: ) -> dict[str, dict[str, dict[str, JsonDict]]]:
"""Query application services for device keys. """Query application services for device keys.

View File

@ -33,8 +33,6 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Mapping, Mapping,
Optional,
Union,
cast, cast,
) )
@ -289,7 +287,7 @@ class AuthHandler:
request_body: dict[str, Any], request_body: dict[str, Any],
description: str, description: str,
can_skip_ui_auth: bool = False, can_skip_ui_auth: bool = False,
) -> tuple[dict, Optional[str]]: ) -> tuple[dict, str | None]:
""" """
Checks that the user is who they claim to be, via a UI auth. Checks that the user is who they claim to be, via a UI auth.
@ -440,7 +438,7 @@ class AuthHandler:
request: SynapseRequest, request: SynapseRequest,
clientdict: dict[str, Any], clientdict: dict[str, Any],
description: str, description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None, get_new_session_data: Callable[[], JsonDict] | None = None,
) -> tuple[dict, dict, str]: ) -> tuple[dict, dict, str]:
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
@ -487,7 +485,7 @@ class AuthHandler:
all the stages in any of the permitted flows. all the stages in any of the permitted flows.
""" """
sid: Optional[str] = None sid: str | None = None
authdict = clientdict.pop("auth", {}) authdict = clientdict.pop("auth", {})
if "session" in authdict: if "session" in authdict:
sid = authdict["session"] sid = authdict["session"]
@ -637,7 +635,7 @@ class AuthHandler:
authdict["session"], stagetype, result authdict["session"], stagetype, result
) )
def get_session_id(self, clientdict: dict[str, Any]) -> Optional[str]: def get_session_id(self, clientdict: dict[str, Any]) -> str | None:
""" """
Gets the session ID for a client given the client dictionary Gets the session ID for a client given the client dictionary
@ -673,7 +671,7 @@ class AuthHandler:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def get_session_data( async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None self, session_id: str, key: str, default: Any | None = None
) -> Any: ) -> Any:
""" """
Retrieve data stored with set_session_data Retrieve data stored with set_session_data
@ -699,7 +697,7 @@ class AuthHandler:
async def _check_auth_dict( async def _check_auth_dict(
self, authdict: dict[str, Any], clientip: str self, authdict: dict[str, Any], clientip: str
) -> Union[dict[str, Any], str]: ) -> dict[str, Any] | str:
"""Attempt to validate the auth dict provided by a client """Attempt to validate the auth dict provided by a client
Args: Args:
@ -774,9 +772,9 @@ class AuthHandler:
async def refresh_token( async def refresh_token(
self, self,
refresh_token: str, refresh_token: str,
access_token_valid_until_ms: Optional[int], access_token_valid_until_ms: int | None,
refresh_token_valid_until_ms: Optional[int], refresh_token_valid_until_ms: int | None,
) -> tuple[str, str, Optional[int]]: ) -> tuple[str, str, int | None]:
""" """
Consumes a refresh token and generate both a new access token and a new refresh token from it. Consumes a refresh token and generate both a new access token and a new refresh token from it.
@ -909,8 +907,8 @@ class AuthHandler:
self, self,
user_id: str, user_id: str,
duration_ms: int = (2 * 60 * 1000), duration_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None, auth_provider_id: str | None = None,
auth_provider_session_id: Optional[str] = None, auth_provider_session_id: str | None = None,
) -> str: ) -> str:
login_token = self.generate_login_token() login_token = self.generate_login_token()
now = self._clock.time_msec() now = self._clock.time_msec()
@ -928,8 +926,8 @@ class AuthHandler:
self, self,
user_id: str, user_id: str,
device_id: str, device_id: str,
expiry_ts: Optional[int], expiry_ts: int | None,
ultimate_session_expiry_ts: Optional[int], ultimate_session_expiry_ts: int | None,
) -> tuple[str, int]: ) -> tuple[str, int]:
""" """
Creates a new refresh token for the user with the given user ID. Creates a new refresh token for the user with the given user ID.
@ -961,11 +959,11 @@ class AuthHandler:
async def create_access_token_for_user_id( async def create_access_token_for_user_id(
self, self,
user_id: str, user_id: str,
device_id: Optional[str], device_id: str | None,
valid_until_ms: Optional[int], valid_until_ms: int | None,
puppets_user_id: Optional[str] = None, puppets_user_id: str | None = None,
is_appservice_ghost: bool = False, is_appservice_ghost: bool = False,
refresh_token_id: Optional[int] = None, refresh_token_id: int | None = None,
) -> str: ) -> str:
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@ -1034,7 +1032,7 @@ class AuthHandler:
return access_token return access_token
async def check_user_exists(self, user_id: str) -> Optional[str]: async def check_user_exists(self, user_id: str) -> str | None:
""" """
Checks to see if a user with the given id exists. Will check case Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches. insensitively, but return None if there are multiple inexact matches.
@ -1061,9 +1059,7 @@ class AuthHandler:
""" """
return await self.store.is_user_approved(user_id) return await self.store.is_user_approved(user_id)
async def _find_user_id_and_pwd_hash( async def _find_user_id_and_pwd_hash(self, user_id: str) -> tuple[str, str] | None:
self, user_id: str
) -> Optional[tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact insensitively, but will return None if there are multiple inexact
matches. matches.
@ -1141,7 +1137,7 @@ class AuthHandler:
login_submission: dict[str, Any], login_submission: dict[str, Any],
ratelimit: bool = False, ratelimit: bool = False,
is_reauth: bool = False, is_reauth: bool = False,
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't Also used by the user-interactive auth flow to validate auth types which don't
@ -1297,7 +1293,7 @@ class AuthHandler:
self, self,
username: str, username: str,
login_submission: dict[str, Any], login_submission: dict[str, Any],
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Helper for validate_login """Helper for validate_login
Handles login, once we've mapped 3pids onto userids Handles login, once we've mapped 3pids onto userids
@ -1386,7 +1382,7 @@ class AuthHandler:
async def check_password_provider_3pid( async def check_password_provider_3pid(
self, medium: str, address: str, password: str self, medium: str, address: str, password: str
) -> tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]: ) -> tuple[str | None, Callable[["LoginResponse"], Awaitable[None]] | None]:
"""Check if a password provider is able to validate a thirdparty login """Check if a password provider is able to validate a thirdparty login
Args: Args:
@ -1413,7 +1409,7 @@ class AuthHandler:
# if result is None then return (None, None) # if result is None then return (None, None)
return None, None return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: async def _check_local_password(self, user_id: str, password: str) -> str | None:
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are user_id is checked case insensitively, but will return None if there are
@ -1528,8 +1524,8 @@ class AuthHandler:
async def delete_access_tokens_for_user( async def delete_access_tokens_for_user(
self, self,
user_id: str, user_id: str,
except_token_id: Optional[int] = None, except_token_id: int | None = None,
device_id: Optional[str] = None, device_id: str | None = None,
) -> None: ) -> None:
"""Invalidate access tokens belonging to a user """Invalidate access tokens belonging to a user
@ -1700,9 +1696,7 @@ class AuthHandler:
return await defer_to_thread(self.hs.get_reactor(), _do_hash) return await defer_to_thread(self.hs.get_reactor(), _do_hash)
async def validate_hash( async def validate_hash(self, password: str, stored_hash: bytes | str) -> bool:
self, password: str, stored_hash: Union[bytes, str]
) -> bool:
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
@ -1799,9 +1793,9 @@ class AuthHandler:
auth_provider_id: str, auth_provider_id: str,
request: Request, request: Request,
client_redirect_url: str, client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None, extra_attributes: JsonDict | None = None,
new_user: bool = False, new_user: bool = False,
auth_provider_session_id: Optional[str] = None, auth_provider_session_id: str | None = None,
) -> None: ) -> None:
"""Having figured out a mxid for this user, complete the HTTP request """Having figured out a mxid for this user, complete the HTTP request
@ -1960,7 +1954,7 @@ def load_single_legacy_password_auth_provider(
# All methods that the module provides should be async, but this wasn't enforced # All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed # in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: def async_wrapper(f: Callable | None) -> Callable[..., Awaitable] | None:
# f might be None if the callback isn't implemented by the module. In this # f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None. # case we don't want to register a callback at all so we return None.
if f is None: if f is None:
@ -1973,7 +1967,7 @@ def load_single_legacy_password_auth_provider(
async def wrapped_check_password( async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable]]]: ) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well # We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not # across function boundaries so we need to tell it f is definitely not
# None. # None.
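The wrapped legacy callbacks above return `tuple[str, Callable | None] | None`: either no match, or a user ID plus an optional post-login callback. A compact sketch of producing and consuming such a value, with hypothetical names and a hard-coded credential check:

    import asyncio
    from typing import Awaitable, Callable

    CheckResult = tuple[str, Callable[[], Awaitable[None]] | None] | None

    async def fake_check_password(username: str, password: str) -> CheckResult:
        # Hypothetical checker: one hard-coded user, no post-login callback.
        if username == "@alice:example.org" and password == "wonderland":
            return username, None
        return None

    async def login(username: str, password: str) -> str | None:
        result = await fake_check_password(username, password)
        if result is None:
            return None
        user_id, post_login = result
        if post_login is not None:
            await post_login()
        return user_id

    assert asyncio.run(login("@alice:example.org", "wonderland")) == "@alice:example.org"
    assert asyncio.run(login("@mallory:example.org", "guess")) is None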
@ -1992,12 +1986,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_password return wrapped_check_password
# We need to wrap check_auth as in the old form it could return # We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[tuple[str, Optional[Callable]] # just a str, but now it must return tuple[str, Callable | None] | None
if f.__name__ == "check_auth": if f.__name__ == "check_auth":
async def wrapped_check_auth( async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable]]]: ) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well # We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not # across function boundaries so we need to tell it f is definitely not
# None. # None.
@ -2013,12 +2007,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_auth return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return # We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[tuple[str, Optional[Callable]] # just a str, but now it must return tuple[str, Callable | None] | None
if f.__name__ == "check_3pid_auth": if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth( async def wrapped_check_3pid_auth(
medium: str, address: str, password: str medium: str, address: str, password: str
) -> Optional[tuple[str, Optional[Callable]]]: ) -> tuple[str, Callable | None] | None:
# We've already made sure f is not None above, but mypy doesn't do well # We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not # across function boundaries so we need to tell it f is definitely not
# None. # None.
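The wrappers above all follow the same adapter pattern described in the comments: a legacy callback that may return a bare user-id string is adapted so the new-style hook always yields tuple[str, Callable | None] | None. A compressed, hypothetical sketch of that pattern (names invented, not the real module API):

from collections.abc import Awaitable, Callable

def wrap_legacy(
    f: Callable[[str, str], Awaitable[str | None]],
) -> Callable[[str, str], Awaitable[tuple[str, Callable | None] | None]]:
    # Adapt a legacy callback returning a bare user-id string (or None)
    # into the new shape: tuple[str, Callable | None] | None.
    async def wrapped(username: str, password: str) -> tuple[str, Callable | None] | None:
        result = await f(username, password)
        if result is None:
            return None
        return result, None  # the legacy form carries no post-login callback
    return wrapped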
@ -2044,10 +2038,10 @@ def load_single_legacy_password_auth_provider(
# If the module has these methods implemented, then we pull them out # If the module has these methods implemented, then we pull them out
# and register them as hooks. # and register them as hooks.
check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper( check_3pid_auth_hook: CHECK_3PID_AUTH_CALLBACK | None = async_wrapper(
getattr(provider, "check_3pid_auth", None) getattr(provider, "check_3pid_auth", None)
) )
on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper( on_logged_out_hook: ON_LOGGED_OUT_CALLBACK | None = async_wrapper(
getattr(provider, "on_logged_out", None) getattr(provider, "on_logged_out", None)
) )
@ -2085,24 +2079,20 @@ def load_single_legacy_password_auth_provider(
CHECK_3PID_AUTH_CALLBACK = Callable[ CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str], [str, str, str],
Awaitable[ Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None],
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
] ]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] ON_LOGGED_OUT_CALLBACK = Callable[[str, str | None, str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[ CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict], [str, str, JsonDict],
Awaitable[ Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None],
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
] ]
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict], [JsonDict, JsonDict],
Awaitable[Optional[str]], Awaitable[str | None],
] ]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict], [JsonDict, JsonDict],
Awaitable[Optional[str]], Awaitable[str | None],
] ]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
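These aliases also show that PEP 604 unions are ordinary runtime objects, so they nest inside Callable/Awaitable subscriptions without a __future__ import. A standalone sketch mirroring the shape of ON_LOGGED_OUT_CALLBACK (the alias and hook names below are invented):

from collections.abc import Awaitable, Callable

# Alias mirroring the shape used above; str | None is evaluated eagerly,
# so this works on 3.10+ without `from __future__ import annotations`.
LOGOUT_HOOK = Callable[[str, str | None, str], Awaitable[None]]

async def log_logout(user_id: str, device_id: str | None, access_token: str) -> None:
    print(f"user {user_id} logged out of device {device_id!r}")

hook: LOGOUT_HOOK = log_logout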
@ -2133,18 +2123,15 @@ class PasswordAuthProvider:
def register_password_auth_provider_callbacks( def register_password_auth_provider_callbacks(
self, self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, check_3pid_auth: CHECK_3PID_AUTH_CALLBACK | None = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, on_logged_out: ON_LOGGED_OUT_CALLBACK | None = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, is_3pid_allowed: IS_3PID_ALLOWED_CALLBACK | None = None,
auth_checkers: Optional[ auth_checkers: dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK]
dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK] | None = None,
] = None, get_username_for_registration: GET_USERNAME_FOR_REGISTRATION_CALLBACK
get_username_for_registration: Optional[ | None = None,
GET_USERNAME_FOR_REGISTRATION_CALLBACK get_displayname_for_registration: GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None, | None = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None: ) -> None:
# Register check_3pid_auth callback # Register check_3pid_auth callback
if check_3pid_auth is not None: if check_3pid_auth is not None:
@ -2214,7 +2201,7 @@ class PasswordAuthProvider:
async def check_auth( async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None:
"""Check if the user has presented valid login credentials """Check if the user has presented valid login credentials
Args: Args:
@ -2245,14 +2232,14 @@ class PasswordAuthProvider:
continue continue
if result is not None: if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]] # Check that the callback returned a tuple[str, Callable | None]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be # result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2: if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
@ -2265,24 +2252,24 @@ class PasswordAuthProvider:
if not isinstance(str_result, str): if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
continue continue
# the second should be Optional[Callable] # the second should be Callable | None
if callback_result is not None: if callback_result is not None:
if not callable(callback_result): if not callable(callback_result):
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
continue continue
# The result is a (str, Optional[callback]) tuple so return the successful result # The result is a (str, callback | None) tuple so return the successful result
return result return result
# If this point has been reached then none of the callbacks successfully authenticated # If this point has been reached then none of the callbacks successfully authenticated
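A minimal, self-contained sketch of the shape check these warnings describe, illustrative only; the expected shape mirrors the diff, while the helper name is invented:

import logging
from typing import Any

logger = logging.getLogger(__name__)

def has_valid_shape(result: Any) -> bool:
    # Third-party callbacks must return tuple[str, Callable | None] | None;
    # anything else is logged and treated as "no result" by the caller.
    if result is None:
        return False
    if not isinstance(result, tuple) or len(result) != 2:
        logger.warning("expected tuple[str, Callable | None] | None, got %r", result)
        return False
    user_id, post_login = result
    return isinstance(user_id, str) and (post_login is None or callable(post_login))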
@ -2291,7 +2278,7 @@ class PasswordAuthProvider:
async def check_3pid_auth( async def check_3pid_auth(
self, medium: str, address: str, password: str self, medium: str, address: str, password: str
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None:
# This function is able to return a deferred that either # This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon # resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of # success, to a str (which is the user_id) or a tuple of
@ -2308,14 +2295,14 @@ class PasswordAuthProvider:
continue continue
if result is not None: if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]] # Check that the callback returned a tuple[str, Callable | None]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be # result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2: if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
@ -2328,24 +2315,24 @@ class PasswordAuthProvider:
if not isinstance(str_result, str): if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
continue continue
# the second should be Optional[Callable] # the second should be Callable | None
if callback_result is not None: if callback_result is not None:
if not callable(callback_result): if not callable(callback_result):
logger.warning( # type: ignore[unreachable] logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected" "Wrong type returned by module API callback %s: %s, expected"
" Optional[tuple[str, Optional[Callable]]]", " tuple[str, Callable | None] | None",
callback, callback,
result, result,
) )
continue continue
# The result is a (str, Optional[callback]) tuple so return the successful result # The result is a (str, callback | None) tuple so return the successful result
return result return result
# If this point has been reached then none of the callbacks successfully authenticated # If this point has been reached then none of the callbacks successfully authenticated
@ -2353,7 +2340,7 @@ class PasswordAuthProvider:
return None return None
async def on_logged_out( async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str self, user_id: str, device_id: str | None, access_token: str
) -> None: ) -> None:
# call all of the on_logged_out callbacks # call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks: for callback in self.on_logged_out_callbacks:
@ -2367,7 +2354,7 @@ class PasswordAuthProvider:
self, self,
uia_results: JsonDict, uia_results: JsonDict,
params: JsonDict, params: JsonDict,
) -> Optional[str]: ) -> str | None:
"""Defines the username to use when registering the user, using the credentials """Defines the username to use when registering the user, using the credentials
and parameters provided during the UIA flow. and parameters provided during the UIA flow.
@ -2412,7 +2399,7 @@ class PasswordAuthProvider:
self, self,
uia_results: JsonDict, uia_results: JsonDict,
params: JsonDict, params: JsonDict,
) -> Optional[str]: ) -> str | None:
"""Defines the display name to use when registering the user, using the """Defines the display name to use when registering the user, using the
credentials and parameters provided during the UIA flow. credentials and parameters provided during the UIA flow.