Refactor background process metrics to be homeserver-scoped (#18670)

Part of https://github.com/element-hq/synapse/issues/18592

Separated out of https://github.com/element-hq/synapse/pull/18656
because it is a larger, self-contained piece of the refactor
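
For context, a minimal sketch of the call-site change this refactor makes (the `ExampleHandler` class is hypothetical; `run_as_background_process` and its new `server_name` argument come from the diff below):

```python
from synapse.metrics.background_process_metrics import run_as_background_process


class ExampleHandler:
    def __init__(self, hs: "HomeServer"):
        # The homeserver name is now threaded through to every background
        # process so that its metrics carry a `server_name` label.
        self.server_name = hs.hostname

    def notify_new_event(self) -> None:
        # Before: run_as_background_process("example.process", self._process)
        # After: the homeserver name is the second positional argument.
        run_as_background_process(
            "example.process",
            self.server_name,
            self._process,
        )

    async def _process(self) -> None:
        ...
```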


### Testing strategy

 1. Add a `metrics` listener to your `homeserver.yaml`:
    ```yaml
    listeners:
      # This is just showing how to configure metrics either way
      #
      # `http` `metrics` resource
      - port: 9322
        type: http
        bind_addresses: ['127.0.0.1']
        resources:
          - names: [metrics]
            compress: false
      # `metrics` listener
      - port: 9323
        type: metrics
        bind_addresses: ['127.0.0.1']
    ```
1. Start the homeserver: `poetry run synapse_homeserver --config-path homeserver.yaml`
1. Fetch `http://localhost:9322/_synapse/metrics` and/or `http://localhost:9323/metrics`
1. Observe that the response includes the background process metrics (`synapse_background_process_start_count`, `synapse_background_process_db_txn_count_total`, etc.) with the `server_name` label
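
If you want to script the final check, here is a small sketch (it assumes the ports from the example config above and uses only the Python standard library):

```python
import urllib.request

# Fetch the Prometheus exposition output from the `http` `metrics` resource
# configured above (port 9322); use port 9323 for the `metrics` listener.
with urllib.request.urlopen("http://localhost:9322/_synapse/metrics") as resp:
    body = resp.read().decode("utf-8")

# Every background-process series should now carry a `server_name` label.
for line in body.splitlines():
    if line.startswith("synapse_background_process_"):
        assert 'server_name="' in line, line
        print(line)
```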
Eric Eastwood 2025-07-23 13:28:17 -05:00 committed by GitHub
parent 8fb9c105c9
commit b7e7f537f1
71 changed files with 906 additions and 406 deletions

changelog.d/18670.misc (new file)

@ -0,0 +1 @@
Refactor background process metrics to be homeserver-scoped.

View File

@ -53,6 +53,7 @@ class MockHomeserver(HomeServer):
def run_background_updates(hs: HomeServer) -> None:
server_name = hs.hostname
main = hs.get_datastores().main
state = hs.get_datastores().state
@ -66,7 +67,11 @@ def run_background_updates(hs: HomeServer) -> None:
def run() -> None:
# Apply all background updates on the database.
defer.ensureDeferred(
run_as_background_process("background_updates", run_background_updates)
run_as_background_process(
"background_updates",
server_name,
run_background_updates,
)
)
reactor.callWhenRunning(run)

View File

@ -75,7 +75,7 @@ from synapse.http.site import SynapseSite
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
@ -512,6 +512,7 @@ async def start(hs: "HomeServer") -> None:
Args:
hs: homeserver instance
"""
server_name = hs.hostname
reactor = hs.get_reactor()
# We want to use a separate thread pool for the resolver so that large
@ -530,16 +531,24 @@ async def start(hs: "HomeServer") -> None:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
async def handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
def handle_sighup(*args: Any, **kwargs: Any) -> "defer.Deferred[None]":
async def _handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
for i, args, kwargs in _sighup_callbacks:
i(*args, **kwargs)
for i, args, kwargs in _sighup_callbacks:
i(*args, **kwargs)
sdnotify(b"READY=1")
sdnotify(b"READY=1")
return run_as_background_process(
"sighup",
server_name,
_handle_sighup,
*args,
**kwargs,
)
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail

View File

@ -26,7 +26,11 @@ from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from prometheus_client import Gauge
from synapse.metrics.background_process_metrics import wrap_as_background_process
from twisted.internet import defer
from synapse.metrics.background_process_metrics import (
run_as_background_process,
)
from synapse.types import JsonDict
from synapse.util.constants import ONE_HOUR_SECONDS, ONE_MINUTE_SECONDS
@ -66,125 +70,136 @@ registered_reserved_users_mau_gauge = Gauge(
)
@wrap_as_background_process("phone_stats_home")
async def phone_stats_home(
def phone_stats_home(
hs: "HomeServer",
stats: JsonDict,
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
) -> None:
"""Collect usage statistics and send them to the configured endpoint.
) -> "defer.Deferred[None]":
server_name = hs.hostname
Args:
hs: the HomeServer object to use for gathering usage data.
stats: the dict in which to store the statistics sent to the configured
endpoint. Mostly used in tests to figure out the data that is supposed to
be sent.
stats_process: statistics about resource usage of the process.
"""
async def _phone_stats_home(
hs: "HomeServer",
stats: JsonDict,
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
) -> None:
"""Collect usage statistics and send them to the configured endpoint.
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
# Ensure the homeserver has started.
assert hs.start_time is not None
uptime = int(now - hs.start_time)
if uptime < 0:
uptime = 0
Args:
hs: the HomeServer object to use for gathering usage data.
stats: the dict in which to store the statistics sent to the configured
endpoint. Mostly used in tests to figure out the data that is supposed to
be sent.
stats_process: statistics about resource usage of the process.
"""
#
# Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
#
old = stats_process[0]
new = (now, resource.getrusage(resource.RUSAGE_SELF))
stats_process[0] = new
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
# Ensure the homeserver has started.
assert hs.start_time is not None
uptime = int(now - hs.start_time)
if uptime < 0:
uptime = 0
# Get RSS in bytes
stats["memory_rss"] = new[1].ru_maxrss
#
# Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
#
old = stats_process[0]
new = (now, resource.getrusage(resource.RUSAGE_SELF))
stats_process[0] = new
# Get CPU time in % of a single core, not % of all cores
used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
old[1].ru_utime + old[1].ru_stime
)
if used_cpu_time == 0 or new[0] == old[0]:
stats["cpu_average"] = 0
else:
stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
# Get RSS in bytes
stats["memory_rss"] = new[1].ru_maxrss
#
# General statistics
#
store = hs.get_datastores().main
common_metrics = await hs.get_common_usage_metrics_manager().get_metrics()
stats["homeserver"] = hs.config.server.server_name
stats["server_context"] = hs.config.server.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
version = sys.version_info
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
stats["total_users"] = await store.count_all_users()
total_nonbridged_users = await store.count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = await store.count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = await store.get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = common_metrics.daily_active_users
stats["monthly_active_users"] = await store.count_monthly_users()
daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await store.count_daily_active_rooms()
stats["daily_messages"] = await store.count_daily_messages()
daily_sent_messages = await store.count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
r30v2_results = await store.count_r30v2_users()
for name, count in r30v2_results.items():
stats["r30v2_users_" + name] = count
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
#
# Database version
#
# This only reports info about the *main* database.
stats["database_engine"] = store.db_pool.engine.module.__name__
stats["database_server_version"] = store.db_pool.engine.server_version
#
# Logging configuration
#
synapse_logger = logging.getLogger("synapse")
log_level = synapse_logger.getEffectiveLevel()
stats["log_level"] = logging.getLevelName(log_level)
logger.info(
"Reporting stats to %s: %s", hs.config.metrics.report_stats_endpoint, stats
)
try:
await hs.get_proxied_http_client().put_json(
hs.config.metrics.report_stats_endpoint, stats
# Get CPU time in % of a single core, not % of all cores
used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
old[1].ru_utime + old[1].ru_stime
)
except Exception as e:
logger.warning("Error reporting stats: %s", e)
if used_cpu_time == 0 or new[0] == old[0]:
stats["cpu_average"] = 0
else:
stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
#
# General statistics
#
store = hs.get_datastores().main
common_metrics = await hs.get_common_usage_metrics_manager().get_metrics()
stats["homeserver"] = hs.config.server.server_name
stats["server_context"] = hs.config.server.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
version = sys.version_info
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
stats["total_users"] = await store.count_all_users()
total_nonbridged_users = await store.count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = await store.count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = await store.get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = common_metrics.daily_active_users
stats["monthly_active_users"] = await store.count_monthly_users()
daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await store.count_daily_active_rooms()
stats["daily_messages"] = await store.count_daily_messages()
daily_sent_messages = await store.count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
r30v2_results = await store.count_r30v2_users()
for name, count in r30v2_results.items():
stats["r30v2_users_" + name] = count
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
#
# Database version
#
# This only reports info about the *main* database.
stats["database_engine"] = store.db_pool.engine.module.__name__
stats["database_server_version"] = store.db_pool.engine.server_version
#
# Logging configuration
#
synapse_logger = logging.getLogger("synapse")
log_level = synapse_logger.getEffectiveLevel()
stats["log_level"] = logging.getLevelName(log_level)
logger.info(
"Reporting stats to %s: %s", hs.config.metrics.report_stats_endpoint, stats
)
try:
await hs.get_proxied_http_client().put_json(
hs.config.metrics.report_stats_endpoint, stats
)
except Exception as e:
logger.warning("Error reporting stats: %s", e)
return run_as_background_process(
"phone_stats_home", server_name, _phone_stats_home, hs, stats, stats_process
)
def start_phone_stats_home(hs: "HomeServer") -> None:
"""
Start the background tasks which report phone home stats.
"""
server_name = hs.hostname
clock = hs.get_clock()
stats: JsonDict = {}
@ -210,25 +225,31 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
)
hs.get_datastores().main.reap_monthly_active_users()
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
current_mau_count = await store.get_monthly_active_count()
current_mau_count_by_service = (
await store.get_monthly_active_count_by_service()
)
reserved_users = await store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
def generate_monthly_active_users() -> "defer.Deferred[None]":
async def _generate_monthly_active_users() -> None:
current_mau_count = 0
current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
current_mau_count = await store.get_monthly_active_count()
current_mau_count_by_service = (
await store.get_monthly_active_count_by_service()
)
reserved_users = await store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
for app_service, count in current_mau_count_by_service.items():
current_mau_by_service_gauge.labels(app_service).set(float(count))
for app_service, count in current_mau_count_by_service.items():
current_mau_by_service_gauge.labels(app_service).set(float(count))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.server.max_mau_value))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.server.max_mau_value))
return run_as_background_process(
"generate_monthly_active_users",
server_name,
_generate_monthly_active_users,
)
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
generate_monthly_active_users()

View File

@ -103,18 +103,16 @@ MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
class ApplicationServiceScheduler:
"""Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
"""
Public facing API for this module. Does the required dependency injection (DI) to
tie the components together. This also serves as the "event_pool", which in this
case is a simple array.
"""
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.txn_ctrl = _TransactionController(hs)
self.store = hs.get_datastores().main
self.as_api = hs.get_application_service_api()
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs)
self.queuer = _ServiceQueuer(self.txn_ctrl, hs)
async def start(self) -> None:
logger.info("Starting appservice scheduler")
@ -184,9 +182,7 @@ class _ServiceQueuer:
appservice at a given time.
"""
def __init__(
self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer"
):
def __init__(self, txn_ctrl: "_TransactionController", hs: "HomeServer"):
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
@ -199,10 +195,11 @@ class _ServiceQueuer:
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
self._msc3202_transaction_extensions_enabled: bool = (
hs.config.experimental.msc3202_transaction_extensions
)
self.server_name = hs.hostname
self.clock = hs.get_clock()
self._store = hs.get_datastores().main
def start_background_request(self, service: ApplicationService) -> None:
@ -210,7 +207,9 @@ class _ServiceQueuer:
if service.id in self.requests_in_flight:
return
run_as_background_process("as-sender", self._send_request, service)
run_as_background_process(
"as-sender", self.server_name, self._send_request, service
)
async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender
@ -359,10 +358,11 @@ class _TransactionController:
(Note we only have one of these in the homeserver.)
"""
def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
self.clock = clock
self.store = store
self.as_api = as_api
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.as_api = hs.get_application_service_api()
# map from service id to recoverer instance
self.recoverers: Dict[str, "_Recoverer"] = {}
@ -446,7 +446,12 @@ class _TransactionController:
logger.info("Starting recoverer for AS ID %s", service.id)
assert service.id not in self.recoverers
recoverer = self.RECOVERER_CLASS(
self.clock, self.store, self.as_api, service, self.on_recovered
self.server_name,
self.clock,
self.store,
self.as_api,
service,
self.on_recovered,
)
self.recoverers[service.id] = recoverer
recoverer.recover()
@ -477,21 +482,24 @@ class _Recoverer:
We have one of these for each appservice which is currently considered DOWN.
Args:
clock (synapse.util.Clock):
store (synapse.storage.DataStore):
as_api (synapse.appservice.api.ApplicationServiceApi):
service (synapse.appservice.ApplicationService): the service we are managing
callback (callable[_Recoverer]): called once the service recovers.
server_name: the homeserver name (used to label metrics) (this should be `hs.hostname`).
clock:
store:
as_api:
service: the service we are managing
callback: called once the service recovers.
"""
def __init__(
self,
server_name: str,
clock: Clock,
store: DataStore,
as_api: ApplicationServiceApi,
service: ApplicationService,
callback: Callable[["_Recoverer"], Awaitable[None]],
):
self.server_name = server_name
self.clock = clock
self.store = store
self.as_api = as_api
@ -504,7 +512,11 @@ class _Recoverer:
delay = 2**self.backoff_counter
logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
self.scheduled_recovery = self.clock.call_later(
delay, run_as_background_process, "as-recoverer", self.retry
delay,
run_as_background_process,
"as-recoverer",
self.server_name,
self.retry,
)
def _backoff(self) -> None:
@ -525,6 +537,7 @@ class _Recoverer:
# Run a retry, which will reschedule a recovery if it fails.
run_as_background_process(
"retry",
self.server_name,
self.retry,
)

View File

@ -152,6 +152,8 @@ class Keyring:
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
self.server_name = hs.hostname
if key_fetchers is None:
# Always fetch keys from the database.
mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
@ -169,7 +171,8 @@ class Keyring:
self._fetch_keys_queue: BatchingQueue[
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
] = BatchingQueue(
"keyring_server",
name="keyring_server",
server_name=self.server_name,
clock=hs.get_clock(),
# The method called to fetch each key
process_batch_callback=self._inner_fetch_key_requests,
@ -473,8 +476,12 @@ class Keyring:
class KeyFetcher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._queue = BatchingQueue(
self.__class__.__name__, hs.get_clock(), self._fetch_keys
name=self.__class__.__name__,
server_name=self.server_name,
clock=hs.get_clock(),
process_batch_callback=self._fetch_keys,
)
async def get_keys(

View File

@ -34,6 +34,7 @@ class InviteAutoAccepter:
def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
# Keep a reference to the Module API.
self._api = api
self.server_name = api.server_name
self._config = config
if not self._config.enabled:
@ -113,6 +114,7 @@ class InviteAutoAccepter:
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
run_as_background_process(
"retry_make_join",
self.server_name,
self._retry_make_join,
event.state_key,
event.state_key,

View File

@ -296,6 +296,7 @@ class _DestinationWakeupQueue:
Staggers waking up of per destination queues to ensure that we don't attempt
to start TLS connections with many hosts all at once, leading to pinned CPU.
"""
# The maximum duration in seconds between queuing up a destination and it
@ -303,6 +304,10 @@ class _DestinationWakeupQueue:
_MAX_TIME_IN_QUEUE = 30.0
sender: "FederationSender" = attr.ib()
server_name: str = attr.ib()
"""
Our homeserver name (used to label metrics) (`hs.hostname`).
"""
clock: Clock = attr.ib()
max_delay_s: int = attr.ib()
@ -427,7 +432,7 @@ class FederationSender(AbstractFederationSender):
1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
)
self._destination_wakeup_queue = _DestinationWakeupQueue(
self, self.clock, max_delay_s=rr_txn_interval_per_room_s
self, self.server_name, self.clock, max_delay_s=rr_txn_interval_per_room_s
)
# Regularly wake up destinations that have outstanding PDUs to be caught up
@ -435,6 +440,7 @@ class FederationSender(AbstractFederationSender):
run_as_background_process,
WAKEUP_RETRY_PERIOD_SEC * 1000.0,
"wake_destinations_needing_catchup",
self.server_name,
self._wake_destinations_needing_catchup,
)
@ -477,7 +483,9 @@ class FederationSender(AbstractFederationSender):
# fire off a processing loop in the background
run_as_background_process(
"process_event_queue_for_federation", self._process_event_queue_loop
"process_event_queue_for_federation",
self.server_name,
self._process_event_queue_loop,
)
async def _process_event_queue_loop(self) -> None:

View File

@ -91,7 +91,7 @@ class PerDestinationQueue:
transaction_manager: "synapse.federation.sender.TransactionManager",
destination: str,
):
self._server_name = hs.hostname
self.server_name = hs.hostname
self._clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
self._store = hs.get_datastores().main
@ -311,6 +311,7 @@ class PerDestinationQueue:
run_as_background_process(
"federation_transaction_transmission_loop",
self.server_name,
self._transaction_transmission_loop,
)
@ -322,7 +323,12 @@ class PerDestinationQueue:
# This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client,
# hence why we throw the result away.
await get_retry_limiter(self._destination, self._clock, self._store)
await get_retry_limiter(
destination=self._destination,
our_server_name=self.server_name,
clock=self._clock,
store=self._store,
)
if self._catching_up:
# we potentially need to catch-up first
@ -566,7 +572,7 @@ class PerDestinationQueue:
new_pdus = await filter_events_for_server(
self._storage_controllers,
self._destination,
self._server_name,
self.server_name,
new_pdus,
redact=False,
filter_out_erased_senders=True,
@ -613,7 +619,7 @@ class PerDestinationQueue:
# Send at most limit EDUs for receipts.
for content in self._pending_receipt_edus[:limit]:
yield Edu(
origin=self._server_name,
origin=self.server_name,
destination=self._destination,
edu_type=EduTypes.RECEIPT,
content=content,
@ -639,7 +645,7 @@ class PerDestinationQueue:
)
edus = [
Edu(
origin=self._server_name,
origin=self.server_name,
destination=self._destination,
edu_type=edu_type,
content=content,
@ -666,7 +672,7 @@ class PerDestinationQueue:
edus = [
Edu(
origin=self._server_name,
origin=self.server_name,
destination=self._destination,
edu_type=EduTypes.DIRECT_TO_DEVICE,
content=content,
@ -739,7 +745,7 @@ class _TransactionQueueManager:
pending_edus.append(
Edu(
origin=self.queue._server_name,
origin=self.queue.server_name,
destination=self.queue._destination,
edu_type=EduTypes.PRESENCE,
content={"push": presence_to_add},

View File

@ -38,6 +38,9 @@ logger = logging.getLogger(__name__)
class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.config = hs.config
self.store = hs.get_datastores().main
self.send_email_handler = hs.get_send_email_handler()

View File

@ -73,7 +73,9 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
class ApplicationServicesHandler:
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.store = hs.get_datastores().main
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
@ -166,7 +168,9 @@ class ApplicationServicesHandler:
except Exception:
logger.error("Application Services Failure")
run_as_background_process("as_scheduler", start_scheduler)
run_as_background_process(
"as_scheduler", self.server_name, start_scheduler
)
self.started_scheduler = True
# Fork off pushes to these services

View File

@ -199,6 +199,7 @@ class AuthHandler:
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
@ -247,6 +248,7 @@ class AuthHandler:
run_as_background_process,
5 * 60 * 1000,
"expire_old_sessions",
self.server_name,
self._expire_old_sessions,
)
@ -271,8 +273,6 @@ class AuthHandler:
hs.config.sso.sso_account_deactivated_template
)
self._server_name = hs.config.server.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist)
@ -1857,7 +1857,7 @@ class AuthHandler:
html = self._sso_redirect_confirm_template.render(
display_url=display_url,
redirect_url=redirect_url,
server_name=self._server_name,
server_name=self.server_name,
new_user=new_user,
user_id=registered_user_id,
user_profile=user_profile_data,

View File

@ -39,6 +39,7 @@ class DeactivateAccountHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.hs = hs
self.server_name = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
@ -243,7 +244,9 @@ class DeactivateAccountHandler:
pending deactivation, if it isn't already running.
"""
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
run_as_background_process(
"user_parter_loop", self.server_name, self._user_parter_loop
)
async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms"""

View File

@ -110,12 +110,13 @@ class DelayedEventsHandler:
# Can send the events in background after having awaited on marking them as processed
run_as_background_process(
"_send_events",
self.server_name,
self._send_events,
events,
)
self._initialized_from_db = run_as_background_process(
"_schedule_db_events", _schedule_db_events
"_schedule_db_events", self.server_name, _schedule_db_events
)
else:
self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs)
@ -140,7 +141,9 @@ class DelayedEventsHandler:
finally:
self._event_processing = False
run_as_background_process("delayed_events.notify_new_event", process)
run_as_background_process(
"delayed_events.notify_new_event", self.server_name, process
)
async def _unsafe_process_new_event(self) -> None:
# If self._event_pos is None then means we haven't fetched it from the DB yet
@ -450,6 +453,7 @@ class DelayedEventsHandler:
delay_sec,
run_as_background_process,
"_send_on_timeout",
self.server_name,
self._send_on_timeout,
)
else:

View File

@ -193,8 +193,9 @@ class DeviceHandler:
self.clock.looping_call(
run_as_background_process,
DELETE_STALE_DEVICES_INTERVAL_MS,
"delete_stale_devices",
self._delete_stale_devices,
desc="delete_stale_devices",
server_name=self.server_name,
func=self._delete_stale_devices,
)
async def _delete_stale_devices(self) -> None:
@ -963,6 +964,9 @@ class DeviceWriterHandler(DeviceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = (
hs.hostname
) # nb must be called this for @measure_func and @wrap_as_background_process
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
# `synapse.app.generic_worker.FederationSenderHandler` when it sees it
@ -1470,6 +1474,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self.clock.looping_call(
run_as_background_process,
30 * 1000,
server_name=self.server_name,
func=self._maybe_retry_device_resync,
desc="_maybe_retry_device_resync",
)
@ -1591,6 +1596,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
await self.store.mark_remote_users_device_caches_as_stale([user_id])
run_as_background_process(
"_maybe_retry_device_resync",
self.server_name,
self.multi_user_device_resync,
[user_id],
False,

View File

@ -187,7 +187,9 @@ class FederationHandler:
# were shut down.
if not hs.config.worker.worker_app:
run_as_background_process(
"resume_sync_partial_state_room", self._resume_partial_state_room_sync
"resume_sync_partial_state_room",
self.server_name,
self._resume_partial_state_room_sync,
)
@trace
@ -316,6 +318,7 @@ class FederationHandler:
)
run_as_background_process(
"_maybe_backfill_inner_anyway_with_max_depth",
self.server_name,
self.maybe_backfill,
room_id=room_id,
# We use `MAX_DEPTH` so that we find all backfill points next
@ -798,7 +801,10 @@ class FederationHandler:
# have. Hence we fire off the background task, but don't wait for it.
run_as_background_process(
"handle_queued_pdus", self._handle_queued_pdus, room_queue
"handle_queued_pdus",
self.server_name,
self._handle_queued_pdus,
room_queue,
)
async def do_knock(
@ -1870,7 +1876,9 @@ class FederationHandler:
)
run_as_background_process(
desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
desc="sync_partial_state_room",
server_name=self.server_name,
func=_sync_partial_state_room_wrapper,
)
async def _sync_partial_state_room(

View File

@ -146,6 +146,7 @@ class FederationEventHandler:
"""
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
@ -170,7 +171,6 @@ class FederationEventHandler:
self._is_mine_id = hs.is_mine_id
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()
self._config = hs.config
@ -249,7 +249,7 @@ class FederationEventHandler:
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = await self._event_auth_handler.is_host_in_room(
room_id, self._server_name
room_id, self.server_name
)
if not is_in_room:
logger.info(
@ -930,6 +930,7 @@ class FederationEventHandler:
if len(events_with_failed_pull_attempts) > 0:
run_as_background_process(
"_process_new_pulled_events_with_failed_pull_attempts",
self.server_name,
_process_new_pulled_events,
events_with_failed_pull_attempts,
)
@ -1523,6 +1524,7 @@ class FederationEventHandler:
if resync:
run_as_background_process(
"resync_device_due_to_pdu",
self.server_name,
self._resync_device,
event.sender,
)

View File

@ -92,6 +92,7 @@ class MessageHandler:
"""Contains some read only APIs to get state about a room"""
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@ -107,7 +108,7 @@ class MessageHandler:
if not hs.config.worker.worker_app:
run_as_background_process(
"_schedule_next_expiry", self._schedule_next_expiry
"_schedule_next_expiry", self.server_name, self._schedule_next_expiry
)
async def get_room_data(
@ -439,6 +440,7 @@ class MessageHandler:
delay,
run_as_background_process,
"_expire_event",
self.server_name,
self._expire_event,
event_id,
)
@ -541,6 +543,7 @@ class EventCreationHandler:
self.clock.looping_call(
lambda: run_as_background_process(
"send_dummy_events_to_fill_extremities",
self.server_name,
self._send_dummy_events_to_fill_extremities,
),
5 * 60 * 1000,
@ -1942,6 +1945,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while.
run_as_background_process(
"bump_presence_active_time",
self.server_name,
self._bump_active_time,
requester.user,
requester.device_id,

View File

@ -79,12 +79,12 @@ class PaginationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
self._worker_locks = hs.get_worker_locks_handler()
@ -119,6 +119,7 @@ class PaginationHandler:
run_as_background_process,
job.interval,
"purge_history_for_rooms_in_range",
self.server_name,
self.purge_history_for_rooms_in_range,
job.shortest_max_lifetime,
job.longest_max_lifetime,
@ -245,6 +246,7 @@ class PaginationHandler:
# other purges in the same room.
run_as_background_process(
PURGE_HISTORY_ACTION_NAME,
self.server_name,
self.purge_history,
room_id,
token,
@ -395,7 +397,7 @@ class PaginationHandler:
write=True,
):
# first check that we have no users in this room
joined = await self.store.is_host_joined(room_id, self._server_name)
joined = await self.store.is_host_joined(room_id, self.server_name)
if joined:
if force:
logger.info(
@ -604,6 +606,7 @@ class PaginationHandler:
# for a costly federation call and processing.
run_as_background_process(
"maybe_backfill_in_the_background",
self.server_name,
self.hs.get_federation_handler().maybe_backfill,
room_id,
curr_topo,

View File

@ -484,6 +484,7 @@ class _NullContextManager(ContextManager[None]):
class WorkerPresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.hostname
self._presence_writer_instance = hs.config.worker.writers.presence[0]
# Route presence EDUs to the right worker
@ -517,6 +518,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
"shutdown",
run_as_background_process,
"generic_presence.on_shutdown",
self.server_name,
self._on_shutdown,
)
@ -747,7 +749,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.hostname
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
@ -815,6 +819,7 @@ class PresenceHandler(BasePresenceHandler):
"shutdown",
run_as_background_process,
"presence.on_shutdown",
self.server_name,
self._on_shutdown,
)
@ -1495,7 +1500,9 @@ class PresenceHandler(BasePresenceHandler):
finally:
self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence)
run_as_background_process(
"presence.notify_new_event", self.server_name, _process_presence
)
async def _unsafe_process(self) -> None:
# Loop round handling deltas until we're up to date

View File

@ -2164,6 +2164,7 @@ class RoomForgetterHandler(StateDeltasHandler):
super().__init__(hs)
self._hs = hs
self.server_name = hs.hostname
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._clock = hs.get_clock()
@ -2195,7 +2196,9 @@ class RoomForgetterHandler(StateDeltasHandler):
finally:
self._is_processing = False
run_as_background_process("room_forgetter.notify_new_event", process)
run_as_background_process(
"room_forgetter.notify_new_event", self.server_name, process
)
async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB

View File

@ -54,6 +54,7 @@ class StatsHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = hs.hostname
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
@ -89,7 +90,7 @@ class StatsHandler:
finally:
self._is_processing = False
run_as_background_process("stats.notify_new_event", process)
run_as_background_process("stats.notify_new_event", self.server_name, process)
async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB

View File

@ -80,7 +80,9 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
@ -143,7 +145,11 @@ class FollowerTypingHandler:
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
run_as_background_process(
"typing._push_remote", self._push_remote, member=member, typing=True
"typing._push_remote",
self.server_name,
self._push_remote,
member=member,
typing=True,
)
# Add a paranoia timer to ensure that we always have a timer for
@ -216,6 +222,7 @@ class FollowerTypingHandler:
if self.federation:
run_as_background_process(
"_send_changes_in_typing_to_remotes",
self.server_name,
self._send_changes_in_typing_to_remotes,
row.room_id,
prev_typing,
@ -378,7 +385,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
"typing._push_remote", self._push_remote, member, typing
"typing._push_remote",
self.server_name,
self._push_remote,
member,
typing,
)
self._push_update_local(member=member, typing=typing)

View File

@ -192,7 +192,9 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_processing = False
self._is_processing = True
run_as_background_process("user_directory.notify_new_event", process)
run_as_background_process(
"user_directory.notify_new_event", self.server_name, process
)
async def handle_local_profile_change(
self, user_id: str, profile: ProfileInfo
@ -606,7 +608,9 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_refreshing_remote_profiles = False
self._is_refreshing_remote_profiles = True
run_as_background_process("user_directory.refresh_remote_profiles", process)
run_as_background_process(
"user_directory.refresh_remote_profiles", self.server_name, process
)
async def _unsafe_refresh_remote_profiles(self) -> None:
limit = MAX_SERVERS_TO_REFRESH_PROFILES_FOR_IN_ONE_GO - len(
@ -688,7 +692,9 @@ class UserDirectoryHandler(StateDeltasHandler):
self._is_refreshing_remote_profiles_for_servers.add(server_name)
run_as_background_process(
"user_directory.refresh_remote_profiles_for_remote_server", process
"user_directory.refresh_remote_profiles_for_remote_server",
self.server_name,
process,
)
async def _unsafe_refresh_remote_profiles_for_remote_server(

View File

@ -66,6 +66,9 @@ class WorkerLocksHandler:
"""
def __init__(self, hs: "HomeServer") -> None:
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self._reactor = hs.get_reactor()
self._store = hs.get_datastores().main
self._clock = hs.get_clock()

View File

@ -620,9 +620,10 @@ class MatrixFederationHttpClient:
raise FederationDeniedError(request.destination)
limiter = await synapse.util.retryutils.get_retry_limiter(
request.destination,
self.clock,
self._store,
destination=request.destination,
our_server_name=self.server_name,
clock=self.clock,
store=self._store,
backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
notifier=self.hs.get_notifier(),

View File

@ -186,12 +186,16 @@ class MediaRepository:
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
"update_recently_accessed_media",
self.server_name,
self._update_recently_accessed,
)
def _start_apply_media_retention_rules(self) -> Deferred:
return run_as_background_process(
"apply_media_retention_rules", self._apply_media_retention_rules
"apply_media_retention_rules",
self.server_name,
self._apply_media_retention_rules,
)
async def _update_recently_accessed(self) -> None:

View File

@ -740,7 +740,7 @@ class UrlPreviewer:
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
"expire_url_cache_data", self.server_name, self._expire_url_cache_data
)
async def _expire_url_cache_data(self) -> None:

View File

@ -31,6 +31,7 @@ from typing import (
Dict,
Iterable,
Optional,
Protocol,
Set,
Type,
TypeVar,
@ -39,7 +40,7 @@ from typing import (
from prometheus_client import Metric
from prometheus_client.core import REGISTRY, Counter, Gauge
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
@ -49,6 +50,7 @@ from synapse.logging.context import (
PreserveLoggingContext,
)
from synapse.logging.opentracing import SynapseTags, start_active_span
from synapse.metrics import SERVER_NAME_LABEL
from synapse.metrics._types import Collector
if TYPE_CHECKING:
@ -64,13 +66,13 @@ logger = logging.getLogger(__name__)
_background_process_start_count = Counter(
"synapse_background_process_start_count",
"Number of background processes started",
["name"],
labelnames=["name", SERVER_NAME_LABEL],
)
_background_process_in_flight_count = Gauge(
"synapse_background_process_in_flight_count",
"Number of background processes in flight",
labelnames=["name"],
labelnames=["name", SERVER_NAME_LABEL],
)
# we set registry=None in all of these to stop them getting registered with
@ -80,21 +82,21 @@ _background_process_in_flight_count = Gauge(
_background_process_ru_utime = Counter(
"synapse_background_process_ru_utime_seconds",
"User CPU time used by background processes, in seconds",
["name"],
labelnames=["name", SERVER_NAME_LABEL],
registry=None,
)
_background_process_ru_stime = Counter(
"synapse_background_process_ru_stime_seconds",
"System CPU time used by background processes, in seconds",
["name"],
labelnames=["name", SERVER_NAME_LABEL],
registry=None,
)
_background_process_db_txn_count = Counter(
"synapse_background_process_db_txn_count",
"Number of database transactions done by background processes",
["name"],
labelnames=["name", SERVER_NAME_LABEL],
registry=None,
)
@ -104,14 +106,14 @@ _background_process_db_txn_duration = Counter(
"Seconds spent by background processes waiting for database "
"transactions, excluding scheduling time"
),
["name"],
labelnames=["name", SERVER_NAME_LABEL],
registry=None,
)
_background_process_db_sched_duration = Counter(
"synapse_background_process_db_sched_duration_seconds",
"Seconds spent by background processes waiting for database connections",
["name"],
labelnames=["name", SERVER_NAME_LABEL],
registry=None,
)
@ -169,8 +171,9 @@ REGISTRY.register(_Collector())
class _BackgroundProcess:
def __init__(self, desc: str, ctx: LoggingContext):
def __init__(self, *, desc: str, server_name: str, ctx: LoggingContext):
self.desc = desc
self.server_name = server_name
self._context = ctx
self._reported_stats: Optional[ContextResourceUsage] = None
@ -185,15 +188,21 @@ class _BackgroundProcess:
# For unknown reasons, the difference in times can be negative. See comment in
# synapse.http.request_metrics.RequestMetrics.update_metrics.
_background_process_ru_utime.labels(self.desc).inc(max(diff.ru_utime, 0))
_background_process_ru_stime.labels(self.desc).inc(max(diff.ru_stime, 0))
_background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count)
_background_process_db_txn_duration.labels(self.desc).inc(
diff.db_txn_duration_sec
)
_background_process_db_sched_duration.labels(self.desc).inc(
diff.db_sched_duration_sec
)
_background_process_ru_utime.labels(
name=self.desc, **{SERVER_NAME_LABEL: self.server_name}
).inc(max(diff.ru_utime, 0))
_background_process_ru_stime.labels(
name=self.desc, **{SERVER_NAME_LABEL: self.server_name}
).inc(max(diff.ru_stime, 0))
_background_process_db_txn_count.labels(
name=self.desc, **{SERVER_NAME_LABEL: self.server_name}
).inc(diff.db_txn_count)
_background_process_db_txn_duration.labels(
name=self.desc, **{SERVER_NAME_LABEL: self.server_name}
).inc(diff.db_txn_duration_sec)
_background_process_db_sched_duration.labels(
name=self.desc, **{SERVER_NAME_LABEL: self.server_name}
).inc(diff.db_sched_duration_sec)
R = TypeVar("R")
@ -201,6 +210,7 @@ R = TypeVar("R")
def run_as_background_process(
desc: "LiteralString",
server_name: str,
func: Callable[..., Awaitable[Optional[R]]],
*args: Any,
bg_start_span: bool = True,
@ -218,6 +228,8 @@ def run_as_background_process(
Args:
desc: a description for this background process type
server_name: The homeserver name that this background process is being run for
(this should be `hs.hostname`).
func: a function, which may return a Deferred or a coroutine
bg_start_span: Whether to start an opentracing span. Defaults to True.
Should only be disabled for processes that will not log to or tag
@ -236,10 +248,16 @@ def run_as_background_process(
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
_background_process_start_count.labels(
name=desc, **{SERVER_NAME_LABEL: server_name}
).inc()
_background_process_in_flight_count.labels(
name=desc, **{SERVER_NAME_LABEL: server_name}
).inc()
with BackgroundProcessLoggingContext(desc, count) as context:
with BackgroundProcessLoggingContext(
name=desc, server_name=server_name, instance_id=count
) as context:
try:
if bg_start_span:
ctx = start_active_span(
@ -256,7 +274,9 @@ def run_as_background_process(
)
return None
finally:
_background_process_in_flight_count.labels(desc).dec()
_background_process_in_flight_count.labels(
name=desc, **{SERVER_NAME_LABEL: server_name}
).dec()
with PreserveLoggingContext():
# Note that we return a Deferred here so that it can be used in a
@ -267,6 +287,14 @@ def run_as_background_process(
P = ParamSpec("P")
class HasServerName(Protocol):
server_name: str
"""
The homeserver name that this cache is associated with (used to label the metric)
(`hs.hostname`).
"""
def wrap_as_background_process(
desc: "LiteralString",
) -> Callable[
@ -292,22 +320,37 @@ def wrap_as_background_process(
multiple places.
"""
def wrap_as_background_process_inner(
func: Callable[P, Awaitable[Optional[R]]],
def wrapper(
func: Callable[Concatenate[HasServerName, P], Awaitable[Optional[R]]],
) -> Callable[P, "defer.Deferred[Optional[R]]"]:
@wraps(func)
def wrap_as_background_process_inner_2(
*args: P.args, **kwargs: P.kwargs
def wrapped_func(
self: HasServerName, *args: P.args, **kwargs: P.kwargs
) -> "defer.Deferred[Optional[R]]":
# type-ignore: mypy is confusing kwargs with the bg_start_span kwarg.
# Argument 4 to "run_as_background_process" has incompatible type
# "**P.kwargs"; expected "bool"
# See https://github.com/python/mypy/issues/8862
return run_as_background_process(desc, func, *args, **kwargs) # type: ignore[arg-type]
assert self.server_name is not None, (
"The `server_name` attribute must be set on the object where `@wrap_as_background_process` decorator is used."
)
return wrap_as_background_process_inner_2
return run_as_background_process(
desc,
self.server_name,
func,
self,
*args,
# type-ignore: mypy is confusing kwargs with the bg_start_span kwarg.
# Argument 4 to "run_as_background_process" has incompatible type
# "**P.kwargs"; expected "bool"
# See https://github.com/python/mypy/issues/8862
**kwargs, # type: ignore[arg-type]
)
return wrap_as_background_process_inner
# There are some shenanigans here, because we're decorating a method but
# explicitly making use of the `self` parameter. The key thing here is that the
# return type within the return type for `measure_func` itself describes how the
# decorated function will be called.
return wrapped_func # type: ignore[return-value]
return wrapper # type: ignore[return-value]
class BackgroundProcessLoggingContext(LoggingContext):
@ -317,13 +360,20 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
def __init__(self, name: str, instance_id: Optional[Union[int, str]] = None):
def __init__(
self,
*,
name: str,
server_name: str,
instance_id: Optional[Union[int, str]] = None,
):
"""
Args:
name: The name of the background process. Each distinct `name` gets a
separate prometheus time series.
server_name: The homeserver name that this background process is being run for
(this should be `hs.hostname`).
instance_id: an identifier to add to `name` to distinguish this instance of
the named background process in the logs. If this is `None`, one is
made up based on id(self).
@ -331,7 +381,9 @@ class BackgroundProcessLoggingContext(LoggingContext):
if instance_id is None:
instance_id = id(self)
super().__init__("%s-%s" % (name, instance_id))
self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self)
self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(
desc=name, server_name=server_name, ctx=self
)
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""

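The diff above also changes `@wrap_as_background_process` so that it reads `self.server_name` from the object whose method it decorates (the `HasServerName` protocol, with an assert that the attribute is set). A minimal sketch of a class satisfying that contract (the class and method names are illustrative):

```python
from synapse.metrics.background_process_metrics import wrap_as_background_process


class ExampleMetricsCollector:
    def __init__(self, hs: "HomeServer"):
        # nb: must be called `server_name` for @wrap_as_background_process,
        # which uses it to label the background process metrics.
        self.server_name = hs.hostname

    @wrap_as_background_process("example_update_gauges")
    async def update_gauges(self) -> None:
        ...  # runs as a background process labelled with our server_name
```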
View File

@ -47,6 +47,7 @@ class CommonUsageMetricsManager:
"""Collects common usage metrics."""
def __init__(self, hs: "HomeServer") -> None:
self.server_name = hs.hostname
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
@ -62,12 +63,15 @@ class CommonUsageMetricsManager:
async def setup(self) -> None:
"""Keep the gauges for common usage metrics up to date."""
run_as_background_process(
desc="common_usage_metrics_update_gauges", func=self._update_gauges
desc="common_usage_metrics_update_gauges",
server_name=self.server_name,
func=self._update_gauges,
)
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
desc="common_usage_metrics_update_gauges",
server_name=self.server_name,
func=self._update_gauges,
)

View File

@ -1326,6 +1326,7 @@ class ModuleApi:
run_as_background_process,
msec,
desc,
self.server_name,
lambda: maybe_awaitable(f(*args, **kwargs)),
)
else:
@ -1383,6 +1384,7 @@ class ModuleApi:
msec * 0.001,
run_as_background_process,
desc,
self.server_name,
lambda: maybe_awaitable(f(*args, **kwargs)),
)

View File

@ -68,6 +68,7 @@ class EmailPusher(Pusher):
super().__init__(hs, pusher_config)
self.mailer = mailer
self.server_name = hs.hostname
self.store = self.hs.get_datastores().main
self.email = pusher_config.pushkey
self.timed_call: Optional[IDelayedCall] = None
@ -117,7 +118,7 @@ class EmailPusher(Pusher):
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
run_as_background_process("emailpush.process", self.server_name, self._process)
def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events.

View File

@ -106,6 +106,7 @@ class HttpPusher(Pusher):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config)
self.server_name = hs.hostname
self._storage_controllers = self.hs.get_storage_controllers()
self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name
@ -176,7 +177,9 @@ class HttpPusher(Pusher):
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
run_as_background_process(
"http_pusher.on_new_receipts", self.server_name, self._update_badge
)
async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
@ -211,7 +214,7 @@ class HttpPusher(Pusher):
if self.failing_since and self.timed_call and self.timed_call.active():
return
run_as_background_process("httppush.process", self._process)
run_as_background_process("httppush.process", self.server_name, self._process)
async def _process(self) -> None:
# we should never get here if we are already processing

View File

@ -65,6 +65,9 @@ class PusherPool:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.pusher_factory = PusherFactory(hs)
self.store = self.hs.get_datastores().main
self.clock = self.hs.get_clock()
@ -99,7 +102,9 @@ class PusherPool:
if not self._should_start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
run_as_background_process("start_pushers", self._start_pushers)
run_as_background_process(
"start_pushers", self.server_name, self._start_pushers
)
async def add_or_update_pusher(
self,

View File

@ -413,6 +413,7 @@ class FederationSenderHandler:
def __init__(self, hs: "HomeServer"):
assert hs.should_send_federation()
self.server_name = hs.hostname
self.store = hs.get_datastores().main
self._is_mine_id = hs.is_mine_id
self._hs = hs
@ -503,7 +504,9 @@ class FederationSenderHandler:
# no need to queue up another task.
return
run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
run_as_background_process(
"_save_and_send_ack", self.server_name, self._save_and_send_ack
)
async def _save_and_send_ack(self) -> None:
"""Save the current federation position in the database and send an ACK

View File

@ -106,6 +106,7 @@ class ReplicationCommandHandler:
"""
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastores().main
@ -340,7 +341,10 @@ class ReplicationCommandHandler:
# fire off a background process to start processing the queue.
run_as_background_process(
"process-replication-data", self._unsafe_process_queue, stream_name
"process-replication-data",
self.server_name,
self._unsafe_process_queue,
stream_name,
)
async def _unsafe_process_queue(self, stream_name: str) -> None:

View File

@ -39,7 +39,7 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import LaterGauge
from synapse.metrics import SERVER_NAME_LABEL, LaterGauge
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
@ -64,19 +64,21 @@ if TYPE_CHECKING:
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
"synapse_replication_tcp_protocol_close_reason",
"",
labelnames=["reason_type", SERVER_NAME_LABEL],
)
tcp_inbound_commands_counter = Counter(
"synapse_replication_tcp_protocol_inbound_commands",
"Number of commands received from replication, by command and name of process connected to",
["command", "name"],
labelnames=["command", "name", SERVER_NAME_LABEL],
)
tcp_outbound_commands_counter = Counter(
"synapse_replication_tcp_protocol_outbound_commands",
"Number of commands sent to replication, by command and name of process connected to",
["command", "name"],
labelnames=["command", "name", SERVER_NAME_LABEL],
)
# A list of all connected protocols. This allows us to send metrics about the
@ -137,7 +139,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000
def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
def __init__(
self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
):
self.server_name = server_name
self.clock = clock
self.command_handler = handler
@ -166,7 +171,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# capture the sentinel context as its containing context and won't prevent
# GC of / unintentionally reactivate what would be the current context.
self._logging_context = BackgroundProcessLoggingContext(
"replication-conn", self.conn_id
name="replication-conn",
server_name=self.server_name,
instance_id=self.conn_id,
)
def connectionMade(self) -> None:
@ -244,7 +251,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
tcp_inbound_commands_counter.labels(
command=cmd.NAME,
name=self.name,
**{SERVER_NAME_LABEL: self.server_name},
).inc()
self.handle_command(cmd)
@ -280,7 +291,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
"replication-" + cmd.get_logcontext_id(),
self.server_name,
lambda: res,
)
handled = True
@ -318,7 +331,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._queue_command(cmd)
return
tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
tcp_outbound_commands_counter.labels(
command=cmd.NAME,
name=self.name,
**{SERVER_NAME_LABEL: self.server_name},
).inc()
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
@ -390,9 +407,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
connection_close_counter.labels(
reason_type=reason.type.__name__,
**{SERVER_NAME_LABEL: self.server_name},
).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable]
connection_close_counter.labels( # type: ignore[unreachable]
reason_type=reason.__class__.__name__,
**{SERVER_NAME_LABEL: self.server_name},
).inc()
try:
# Remove us from list of connections to be monitored
@ -449,7 +472,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
):
super().__init__(clock, handler)
super().__init__(server_name, clock, handler)
self.server_name = server_name
@ -474,7 +497,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
clock: Clock,
command_handler: "ReplicationCommandHandler",
):
super().__init__(clock, command_handler)
super().__init__(server_name, clock, command_handler)
self.client_name = client_name
self.server_name = server_name

View File

@ -36,7 +36,11 @@ from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IAddress, IConnector
from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
)
from synapse.metrics import SERVER_NAME_LABEL
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
@ -97,6 +101,9 @@ class RedisSubscriber(SubscriberProtocol):
immediately after initialisation.
Attributes:
server_name: The homeserver name of the Synapse instance that this connection
is associated with. This is used to label metrics and should be set to
`hs.hostname`.
synapse_handler: The command handler to handle incoming commands.
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
@ -104,6 +111,7 @@ class RedisSubscriber(SubscriberProtocol):
commands.
"""
server_name: str
synapse_handler: "ReplicationCommandHandler"
synapse_stream_prefix: str
synapse_channel_names: List[str]
@ -114,18 +122,36 @@ class RedisSubscriber(SubscriberProtocol):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
with PreserveLoggingContext():
# thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
# capture the sentinel context as its containing context and won't prevent
# GC of / unintentionally reactivate what would be the current context.
self._logging_context = BackgroundProcessLoggingContext(
"replication_command_handler"
)
self._logging_context: Optional[BackgroundProcessLoggingContext] = None
def _get_logging_context(self) -> BackgroundProcessLoggingContext:
"""
We lazily create the logging context so that `self.server_name` is set and
available. See `RedisDirectTcpReplicationClientFactory.buildProtocol` for more
details on why we set `self.server_name` after the fact instead of in the
constructor.
"""
assert self.server_name is not None, (
"self.server_name must be set before using _get_logging_context()"
)
if self._logging_context is None:
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
with PreserveLoggingContext():
# thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
# capture the sentinel context as its containing context and won't prevent
# GC of / unintentionally reactivate what would be the current context.
self._logging_context = BackgroundProcessLoggingContext(
name="replication_command_handler", server_name=self.server_name
)
return self._logging_context
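As the docstring explains, `server_name` is only assigned on the protocol after `buildProtocol` runs, so the logging context has to be created on first use rather than in the constructor. A standalone, hypothetical sketch of that lazy-initialisation pattern (not the Synapse classes themselves):

```python
from typing import Optional


class LazyContextProtocolSketch:
    # Assigned by the factory *after* construction, mirroring how
    # RedisDirectTcpReplicationClientFactory.buildProtocol sets attributes.
    server_name: str

    def __init__(self) -> None:
        self._logging_context: Optional[str] = None

    def _get_logging_context(self) -> str:
        assert self.server_name is not None, "server_name must be set first"
        if self._logging_context is None:
            # Created lazily, once server_name is available.
            self._logging_context = f"replication_command_handler@{self.server_name}"
        return self._logging_context


protocol = LazyContextProtocolSketch()
protocol.server_name = "example.com"  # what the factory does after buildProtocol
print(protocol._get_logging_context())
```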
def connectionMade(self) -> None:
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
run_as_background_process(
"subscribe-replication", self.server_name, self._send_subscribe
)
async def _send_subscribe(self) -> None:
# it's important to make sure that we only send the REPLICATE command once we
@ -152,7 +178,7 @@ class RedisSubscriber(SubscriberProtocol):
def messageReceived(self, pattern: str, channel: str, message: str) -> None:
"""Received a message from redis."""
with PreserveLoggingContext(self._logging_context):
with PreserveLoggingContext(self._get_logging_context()):
self._parse_and_dispatch_message(message)
def _parse_and_dispatch_message(self, message: str) -> None:
@ -171,7 +197,11 @@ class RedisSubscriber(SubscriberProtocol):
# We use "redis" as the name here as we don't have 1:1 connections to
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
tcp_inbound_commands_counter.labels(
command=cmd.NAME,
name="redis",
**{SERVER_NAME_LABEL: self.server_name},
).inc()
self.handle_command(cmd)
@ -197,7 +227,7 @@ class RedisSubscriber(SubscriberProtocol):
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
"replication-" + cmd.get_logcontext_id(), self.server_name, lambda: res
)
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
@ -207,7 +237,7 @@ class RedisSubscriber(SubscriberProtocol):
# mark the logging context as finished by triggering `__exit__()`
with PreserveLoggingContext():
with self._logging_context:
with self._get_logging_context():
pass
# the sentinel context is now active, which may not be correct.
# PreserveLoggingContext() will restore the correct logging context.
@ -219,7 +249,11 @@ class RedisSubscriber(SubscriberProtocol):
cmd: The command to send
"""
run_as_background_process(
"send-cmd", self._async_send_command, cmd, bg_start_span=False
"send-cmd",
self.server_name,
self._async_send_command,
cmd,
bg_start_span=False,
)
async def _async_send_command(self, cmd: Command) -> None:
@ -232,7 +266,11 @@ class RedisSubscriber(SubscriberProtocol):
# We use "redis" as the name here as we don't have 1:1 connections to
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
tcp_outbound_commands_counter.labels(
command=cmd.NAME,
name="redis",
**{SERVER_NAME_LABEL: self.server_name},
).inc()
channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
@ -275,6 +313,10 @@ class SynapseRedisFactory(RedisFactory):
convertNumbers=convertNumbers,
)
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
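The `# nb must be called this` comment above (and the same comment in `PurgeEventsStorageController` and `TaskScheduler` later in this diff) exists because `@wrap_as_background_process` is applied to methods and has to find the homeserver name on the bound instance. A hypothetical, self-contained illustration of that convention, not Synapse's real decorator:

```python
import asyncio
import functools


def wrap_as_background_process_sketch(desc):
    """Hypothetical stand-in for the real decorator, for illustration only."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            # The decorator can only discover the homeserver name by looking it
            # up on the bound instance, hence the "must be called this" comments.
            server_name = self.server_name
            print(f"background process {desc!r} running for {server_name}")
            return await func(self, *args, **kwargs)

        return wrapper

    return decorator


class PingSenderSketch:
    def __init__(self, server_name: str) -> None:
        self.server_name = server_name  # nb must be called this

    @wrap_as_background_process_sketch("redis_ping")
    async def _send_ping(self) -> None:
        ...


asyncio.run(PingSenderSketch("example.com")._send_ping())
```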
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
@wrap_as_background_process("redis_ping")
@ -350,6 +392,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
password=hs.config.redis.redis_password,
)
self.server_name = hs.hostname
self.synapse_handler = hs.get_replication_command_handler()
self.synapse_stream_prefix = hs.hostname
self.synapse_channel_names = channel_names
@ -364,6 +407,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
p.server_name = self.server_name
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_prefix = self.synapse_stream_prefix

View File

@ -144,7 +144,9 @@ class ReplicationStreamer:
logger.debug("Notifier poke loop already running")
return
run_as_background_process("replication_notifier", self._run_notifier_loop)
run_as_background_process(
"replication_notifier", self.server_name, self._run_notifier_loop
)
async def _run_notifier_loop(self) -> None:
self.is_looping = True

View File

@ -1221,6 +1221,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.hostname
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
@ -1305,6 +1306,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
if with_relations:
run_as_background_process(
"redact_related_events",
self.server_name,
self._relation_handler.redact_events_related_to,
requester=requester,
event_id=event_id,

View File

@ -423,7 +423,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_distributor(self) -> Distributor:
return Distributor()
return Distributor(server_name=self.hostname)
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:

View File

@ -249,6 +249,7 @@ class BackgroundUpdater:
self._clock = hs.get_clock()
self.db_pool = database
self.hs = hs
self.server_name = hs.hostname
self._database_name = database.name()
@ -395,7 +396,10 @@ class BackgroundUpdater:
self._all_done = False
sleep = self.sleep_enabled
run_as_background_process(
"background_updates", self.run_background_updates, sleep
"background_updates",
self.server_name,
self.run_background_updates,
sleep,
)
async def run_background_updates(self, sleep: bool) -> None:

View File

@ -185,6 +185,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
def __init__(
self,
server_name: str,
per_item_callback: Callable[
[str, _EventPersistQueueTask],
Awaitable[_PersistResult],
@ -195,6 +196,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
The per_item_callback will be called for each item added via add_to_queue,
and its result will be returned via the Deferreds returned from add_to_queue.
"""
self.server_name = server_name
self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {}
self._currently_persisting_rooms: Set[str] = set()
self._per_item_callback = per_item_callback
@ -299,7 +301,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)
run_as_background_process("persist_events", self.server_name, handle_queue_loop)
def _get_drainining_queue(
self, room_id: str
@ -342,7 +344,7 @@ class EventsPersistenceStorageController:
self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue(
self._process_event_persist_queue_task
self.server_name, self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller

View File

@ -46,6 +46,9 @@ class PurgeEventsStorageController:
"""High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self.stores = stores
if hs.config.worker.run_background_tasks:

View File

@ -561,6 +561,7 @@ class DatabasePool:
engine: BaseDatabaseEngine,
):
self.hs = hs
self.server_name = hs.hostname
self._clock = hs.get_clock()
self._txn_limit = database_config.config.get("txn_limit", 0)
self._database_config = database_config
@ -602,6 +603,7 @@ class DatabasePool:
0.0,
run_as_background_process,
"upsert_safety_check",
self.server_name,
self._check_safe_to_upsert,
)
@ -644,6 +646,7 @@ class DatabasePool:
15.0,
run_as_background_process,
"upsert_safety_check",
self.server_name,
self._check_safe_to_upsert,
)

View File

@ -78,6 +78,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
db=database,
notifier=hs.get_replication_notifier(),
stream_name="account_data",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("room_account_data", "instance_name", "stream_id"),

View File

@ -104,10 +104,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# caches to invalidate. (This reduces the amount of writes to the DB
# that happen).
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="caches",
server_name=self.server_name,
instance_name=hs.get_instance_name(),
tables=[
(

View File

@ -109,6 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="to_device",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("device_inbox", "instance_name", "stream_id"),
@ -156,6 +157,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
run_as_background_process,
DEVICE_FEDERATION_INBOX_CLEANUP_INTERVAL_MS,
"_delete_old_federation_inbox_rows",
self.server_name,
self._delete_old_federation_inbox_rows,
)

View File

@ -103,6 +103,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="device_lists_stream",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("device_lists_stream", "instance_name", "stream_id"),

View File

@ -125,6 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),

View File

@ -235,6 +235,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="events",
server_name=self.server_name,
instance_name=hs.get_instance_name(),
tables=[
("events", "instance_name", "stream_ordering"),
@ -249,6 +250,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="backfill",
server_name=self.server_name,
instance_name=hs.get_instance_name(),
tables=[
("events", "instance_name", "stream_ordering"),
@ -334,6 +336,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
server_name=self.server_name,
instance_name=hs.get_instance_name(),
tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_event_stream_sequence",
@ -1138,7 +1141,9 @@ class EventsWorkerStore(SQLBaseStore):
should_start = False
if should_start:
run_as_background_process("fetch_events", self._fetch_thread)
run_as_background_process(
"fetch_events", self.server_name, self._fetch_thread
)
async def _fetch_thread(self) -> None:
"""Services requests for events from `_event_fetch_list`."""

View File

@ -24,9 +24,13 @@ from types import TracebackType
from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet import defer
from twisted.internet.task import LoopingCall
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -196,6 +200,7 @@ class LockStore(SQLBaseStore):
return None
lock = Lock(
self.server_name,
self._reactor,
self._clock,
self,
@ -263,6 +268,7 @@ class LockStore(SQLBaseStore):
)
lock = Lock(
self.server_name,
self._reactor,
self._clock,
self,
@ -366,6 +372,7 @@ class Lock:
def __init__(
self,
server_name: str,
reactor: ISynapseReactor,
clock: Clock,
store: LockStore,
@ -374,6 +381,11 @@ class Lock:
lock_key: str,
token: str,
) -> None:
"""
Args:
server_name: The homeserver name (used to label metrics) (this should be `hs.hostname`).
"""
self._server_name = server_name
self._reactor = reactor
self._clock = clock
self._store = store
@ -396,6 +408,7 @@ class Lock:
self._looping_call = self._clock.looping_call(
self._renew,
_RENEWAL_INTERVAL_MS,
self._server_name,
self._store,
self._clock,
self._read_write,
@ -405,31 +418,55 @@ class Lock:
)
@staticmethod
@wrap_as_background_process("Lock._renew")
async def _renew(
def _renew(
server_name: str,
store: LockStore,
clock: Clock,
read_write: bool,
lock_name: str,
lock_key: str,
token: str,
) -> None:
) -> "defer.Deferred[None]":
"""Renew the lock.
Note: this is a static method, rather than using self.*, so that we
don't end up with a reference to `self` in the reactor, which would stop
this from being cleaned up if we dropped the context manager.
Args:
server_name: The homeserver name (used to label metrics) (this should be `hs.hostname`).
"""
table = "worker_read_write_locks" if read_write else "worker_locks"
await store.db_pool.simple_update(
table=table,
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
updatevalues={"last_renewed_ts": clock.time_msec()},
desc="renew_lock",
async def _internal_renew(
store: LockStore,
clock: Clock,
read_write: bool,
lock_name: str,
lock_key: str,
token: str,
) -> None:
table = "worker_read_write_locks" if read_write else "worker_locks"
await store.db_pool.simple_update(
table=table,
keyvalues={
"lock_name": lock_name,
"lock_key": lock_key,
"token": token,
},
updatevalues={"last_renewed_ts": clock.time_msec()},
desc="renew_lock",
)
return run_as_background_process(
"Lock._renew",
server_name,
_internal_renew,
store,
clock,
read_write,
lock_name,
lock_key,
token,
)
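The rewrite of `_renew` above preserves the original design constraint, a callable that holds no reference to `self` so a dropped `Lock` can still be garbage-collected, while allowing `server_name` to be threaded in from the looping call. A hypothetical sketch of that shape, assuming Synapse is importable:

```python
from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process


def renew_sketch(server_name: str, token: str) -> "defer.Deferred[None]":
    """Standalone renewal callback: everything it needs arrives as arguments,
    so scheduling it via a looping call keeps no reference to the lock object.
    """

    async def _inner(token: str) -> None:
        ...  # e.g. bump the lock's last_renewed_ts row in the database

    return run_as_background_process("Lock._renew", server_name, _inner, token)
```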
async def is_still_valid(self) -> bool:

View File

@ -91,6 +91,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
db=database,
notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
sequence_name="presence_stream_sequence",

View File

@ -146,6 +146,7 @@ class PushRulesWorkerStore(
db=database,
notifier=hs.get_replication_notifier(),
stream_name="push_rules_stream",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("push_rules_stream", "instance_name", "stream_id"),

View File

@ -88,6 +88,7 @@ class PusherWorkerStore(SQLBaseStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="pushers",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("pushers", "instance_name", "id"),

View File

@ -124,6 +124,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
# In the worker store this is an ID tracker which we overwrite in the non-worker
@ -138,6 +139,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
server_name=self.server_name,
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
@ -145,8 +147,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
writers=hs.config.worker.writers.receipts,
)
super().__init__(database, db_conn, hs)
max_receipts_stream_id = self.get_max_receipt_stream_id()
receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
db_conn,

View File

@ -160,6 +160,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_room_stream_sequence",

View File

@ -69,6 +69,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
db=database,
notifier=hs.get_replication_notifier(),
stream_name="thread_subscriptions",
server_name=self.server_name,
instance_name=self._instance_name,
tables=[
("thread_subscriptions", "instance_name", "stream_id"),

View File

@ -195,6 +195,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
db
stream_name: A name for the stream, for use in the `stream_positions`
table. (Does not need to be the same as the replication stream name)
server_name: The homeserver name of the server (used to label metrics)
(this should be `hs.hostname`).
instance_name: The name of this instance.
tables: List of tables associated with the stream. Tuple of table
name, column name that stores the writer's instance name, and
@ -210,10 +212,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
*,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
notifier: "ReplicationNotifier",
stream_name: str,
server_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
sequence_name: str,
@ -223,6 +227,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._db = db
self._notifier = notifier
self._stream_name = stream_name
self.server_name = server_name
self._instance_name = instance_name
self._positive = positive
self._writers = writers
@ -561,6 +566,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self.server_name,
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
@ -597,6 +603,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self.server_name,
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,

View File

@ -85,6 +85,8 @@ class BatchingQueue(Generic[V, R]):
Args:
name: A name for the queue, used for logging contexts and metrics.
This must be unique, otherwise the metrics will be wrong.
server_name: The homeserver name of the server (used to label metrics)
(this should be `hs.hostname`).
clock: The clock to use to schedule work.
process_batch_callback: The callback to be run to process a batch of
work.
@ -92,11 +94,14 @@ class BatchingQueue(Generic[V, R]):
def __init__(
self,
*,
name: str,
server_name: str,
clock: Clock,
process_batch_callback: Callable[[List[V]], Awaitable[R]],
):
self._name = name
self.server_name = server_name
self._clock = clock
# The set of keys currently being processed.
@ -135,7 +140,9 @@ class BatchingQueue(Generic[V, R]):
# If we're not currently processing the key fire off a background
# process to start processing.
if key not in self._processing_keys:
run_as_background_process(self._name, self._process_queue, key)
run_as_background_process(
self._name, self.server_name, self._process_queue, key
)
with self._number_in_flight_metric.track_inprogress():
return await make_deferred_yieldable(d)
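For completeness, a construction sketch under the new keyword-only signature documented above; the queue name, callback, and literal server name are illustrative (in real code `server_name` would be `hs.hostname` and the clock would come from `hs.get_clock()`):

```python
from typing import List

from twisted.internet import reactor

from synapse.util import Clock
from synapse.util.batching_queue import BatchingQueue


async def _process_batch(values: List[str]) -> str:
    # Illustrative batch handler; returns the combined result for the batch.
    return ",".join(values)


queue: "BatchingQueue[str, str]" = BatchingQueue(
    name="example_queue",
    server_name="example.com",  # in real code: hs.hostname
    clock=Clock(reactor),
    process_batch_callback=_process_batch,
)
```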

View File

@ -99,7 +99,9 @@ class ExpiringCache(Generic[KT, VT]):
return
def f() -> "defer.Deferred[None]":
return run_as_background_process("prune_cache", self._prune_cache)
return run_as_background_process(
"prune_cache", server_name, self._prune_cache
)
self._clock.looping_call(f, self._expiry_ms / 2)

View File

@ -45,11 +45,13 @@ from typing import (
overload,
)
from twisted.internet import reactor
from twisted.internet import defer, reactor
from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
)
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
@ -118,103 +120,121 @@ USE_GLOBAL_LIST = False
GLOBAL_ROOT = ListNode["_Node"].create_root_node()
@wrap_as_background_process("LruCache._expire_old_entries")
async def _expire_old_entries(
clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
) -> None:
def _expire_old_entries(
server_name: str,
clock: Clock,
expiry_seconds: float,
autotune_config: Optional[dict],
) -> "defer.Deferred[None]":
"""Walks the global cache list to find cache entries that haven't been
accessed in the given number of seconds, or to evict entries once a given memory threshold has been breached.
"""
if autotune_config:
max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
now = int(clock.time())
node = GLOBAL_ROOT.prev_node
assert node is not None
async def _internal_expire_old_entries(
clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
) -> None:
if autotune_config:
max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
i = 0
now = int(clock.time())
node = GLOBAL_ROOT.prev_node
assert node is not None
logger.debug("Searching for stale caches")
i = 0
evicting_due_to_memory = False
logger.debug("Searching for stale caches")
# determine if we're evicting due to memory
jemalloc_interface = get_jemalloc_stats()
if jemalloc_interface and autotune_config:
try:
jemalloc_interface.refresh_stats()
mem_usage = jemalloc_interface.get_stat("allocated")
if mem_usage > max_cache_memory_usage:
logger.info("Begin memory-based cache eviction.")
evicting_due_to_memory = True
except Exception:
logger.warning(
"Unable to read allocated memory, skipping memory-based cache eviction."
)
evicting_due_to_memory = False
while node is not GLOBAL_ROOT:
# Only the root node isn't a `_TimedListNode`.
assert isinstance(node, _TimedListNode)
# if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
# nothing to do here
if (
node.last_access_ts_secs > now - expiry_seconds
and not evicting_due_to_memory
):
break
# if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer
if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl:
break
cache_entry = node.get_cache_entry()
next_node = node.prev_node
# The node should always have a reference to a cache entry and a valid
# `prev_node`, as we only drop them when we remove the node from the
# list.
assert next_node is not None
assert cache_entry is not None
cache_entry.drop_from_cache()
# Check mem allocation periodically if we are evicting a bunch of caches
if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
# determine if we're evicting due to memory
jemalloc_interface = get_jemalloc_stats()
if jemalloc_interface and autotune_config:
try:
jemalloc_interface.refresh_stats()
mem_usage = jemalloc_interface.get_stat("allocated")
if mem_usage < target_cache_memory_usage:
evicting_due_to_memory = False
logger.info("Stop memory-based cache eviction.")
if mem_usage > max_cache_memory_usage:
logger.info("Begin memory-based cache eviction.")
evicting_due_to_memory = True
except Exception:
logger.warning(
"Unable to read allocated memory, this may affect memory-based cache eviction."
"Unable to read allocated memory, skipping memory-based cache eviction."
)
# If we've failed to read the current memory usage then we
# should stop trying to evict based on memory usage
evicting_due_to_memory = False
# If we do lots of work at once we yield to allow other stuff to happen.
if (i + 1) % 10000 == 0:
logger.debug("Waiting during drop")
if node.last_access_ts_secs > now - expiry_seconds:
await clock.sleep(0.5)
else:
await clock.sleep(0)
logger.debug("Waking during drop")
while node is not GLOBAL_ROOT:
# Only the root node isn't a `_TimedListNode`.
assert isinstance(node, _TimedListNode)
node = next_node
# if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
# nothing to do here
if (
node.last_access_ts_secs > now - expiry_seconds
and not evicting_due_to_memory
):
break
# If we've yielded then our current node may have been evicted, so we
# need to check that it's still valid.
if node.prev_node is None:
break
# if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer
if (
evicting_due_to_memory
and now - node.last_access_ts_secs < min_cache_ttl
):
break
i += 1
cache_entry = node.get_cache_entry()
next_node = node.prev_node
logger.info("Dropped %d items from caches", i)
# The node should always have a reference to a cache entry and a valid
# `prev_node`, as we only drop them when we remove the node from the
# list.
assert next_node is not None
assert cache_entry is not None
cache_entry.drop_from_cache()
# Check mem allocation periodically if we are evicting a bunch of caches
if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
try:
jemalloc_interface.refresh_stats()
mem_usage = jemalloc_interface.get_stat("allocated")
if mem_usage < target_cache_memory_usage:
evicting_due_to_memory = False
logger.info("Stop memory-based cache eviction.")
except Exception:
logger.warning(
"Unable to read allocated memory, this may affect memory-based cache eviction."
)
# If we've failed to read the current memory usage then we
# should stop trying to evict based on memory usage
evicting_due_to_memory = False
# If we do lots of work at once we yield to allow other stuff to happen.
if (i + 1) % 10000 == 0:
logger.debug("Waiting during drop")
if node.last_access_ts_secs > now - expiry_seconds:
await clock.sleep(0.5)
else:
await clock.sleep(0)
logger.debug("Waking during drop")
node = next_node
# If we've yielded then our current node may have been evicted, so we
# need to check that it's still valid.
if node.prev_node is None:
break
i += 1
logger.info("Dropped %d items from caches", i)
return run_as_background_process(
"LruCache._expire_old_entries",
server_name,
_internal_expire_old_entries,
clock,
expiry_seconds,
autotune_config,
)
def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
@ -234,10 +254,12 @@ def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
global USE_GLOBAL_LIST
USE_GLOBAL_LIST = True
server_name = hs.hostname
clock = hs.get_clock()
clock.looping_call(
_expire_old_entries,
30 * 1000,
server_name,
clock,
expiry_time,
hs.config.caches.cache_autotuning,

View File

@ -58,7 +58,13 @@ class Distributor:
model will do for today.
"""
def __init__(self) -> None:
def __init__(self, server_name: str) -> None:
"""
Args:
server_name: The homeserver name of the server (used to label metrics)
(this should be `hs.hostname`).
"""
self.server_name = server_name
self.signals: Dict[str, Signal] = {}
self.pre_registration: Dict[str, List[Callable]] = {}
@ -91,7 +97,9 @@ class Distributor:
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
run_as_background_process(
name, self.server_name, self.signals[name].fire, *args, **kwargs
)
P = ParamSpec("P")

View File

@ -59,7 +59,9 @@ class NotRetryingDestination(Exception):
async def get_retry_limiter(
*,
destination: str,
our_server_name: str,
clock: Clock,
store: DataStore,
ignore_backoff: bool = False,
@ -74,6 +76,7 @@ async def get_retry_limiter(
Args:
destination: name of homeserver
our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`)
clock: timing source
store: datastore
ignore_backoff: true to ignore the historical backoff data and
@ -82,7 +85,12 @@ async def get_retry_limiter(
Example usage:
try:
limiter = await get_retry_limiter(destination, clock, store)
limiter = await get_retry_limiter(
destination=destination,
our_server_name=self.server_name,
clock=clock,
store=store,
)
with limiter:
response = await do_request()
except NotRetryingDestination:
@ -114,11 +122,12 @@ async def get_retry_limiter(
backoff_on_failure = not ignore_backoff
return RetryDestinationLimiter(
destination,
clock,
store,
failure_ts,
retry_interval,
destination=destination,
our_server_name=our_server_name,
clock=clock,
store=store,
failure_ts=failure_ts,
retry_interval=retry_interval,
backoff_on_failure=backoff_on_failure,
**kwargs,
)
@ -151,7 +160,9 @@ async def filter_destinations_by_retry_limiter(
class RetryDestinationLimiter:
def __init__(
self,
*,
destination: str,
our_server_name: str,
clock: Clock,
store: DataStore,
failure_ts: Optional[int],
@ -169,6 +180,7 @@ class RetryDestinationLimiter:
Args:
destination
our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`)
clock
store
failure_ts: when this destination started failing (in ms since
@ -184,6 +196,7 @@ class RetryDestinationLimiter:
backoff_on_all_error_codes: Whether we should back off on any
error code.
"""
self.our_server_name = our_server_name
self.clock = clock
self.store = store
self.destination = destination
@ -318,4 +331,6 @@ class RetryDestinationLimiter:
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
run_as_background_process("store_retry_timings", store_retry_timings)
run_as_background_process(
"store_retry_timings", self.our_server_name, store_retry_timings
)

View File

@ -101,6 +101,9 @@ class TaskScheduler:
def __init__(self, hs: "HomeServer"):
self._hs = hs
self.server_name = (
hs.hostname
) # nb must be called this for @wrap_as_background_process
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._running_tasks: Set[str] = set()
@ -354,7 +357,7 @@ class TaskScheduler:
finally:
self._launching_new_tasks = False
run_as_background_process("launch_scheduled_tasks", inner)
run_as_background_process("launch_scheduled_tasks", self.server_name, inner)
@wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None:
@ -485,4 +488,4 @@ class TaskScheduler:
self._running_tasks.add(task.id)
await self.update_task(task.id, status=TaskStatus.ACTIVE)
run_as_background_process(f"task-{task.action}", wrapper)
run_as_background_process(f"task-{task.action}", self.server_name, wrapper)

View File

@ -53,11 +53,24 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.clock = MockClock()
self.store = Mock()
self.as_api = Mock()
self.hs = Mock(
spec_set=[
"get_datastores",
"get_clock",
"get_application_service_api",
"hostname",
]
)
self.hs.get_clock.return_value = self.clock
self.hs.get_datastores.return_value = Mock(
main=self.store,
)
self.hs.get_application_service_api.return_value = self.as_api
self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController(
clock=cast(Clock, self.clock), store=self.store, as_api=self.as_api
)
self.txnctrl = _TransactionController(self.hs)
self.txnctrl.RECOVERER_CLASS = self.recoverer_fn
def test_single_service_up_txn_sent(self) -> None:
@ -163,6 +176,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.service = Mock()
self.callback = AsyncMock()
self.recoverer = _Recoverer(
server_name="test_server",
clock=cast(Clock, self.clock),
as_api=self.as_api,
store=self.store,

View File

@ -14,6 +14,8 @@ class TestBackgroundProcessMetrics(StdlibTestCase):
mock_logging_context = Mock(spec=LoggingContext)
mock_logging_context.get_resource_usage.return_value = usage
process = _BackgroundProcess("test process", mock_logging_context)
process = _BackgroundProcess(
desc="test process", server_name="test_server", ctx=mock_logging_context
)
# Should not raise
process.update_metrics()

View File

@ -80,10 +80,11 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
db_conn=conn,
db=self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
server_name=self.hs.hostname,
instance_name=instance_name,
tables=[(table, "instance_name", "stream_id") for table in self.tables],
sequence_name="foobar_seq",

View File

@ -28,7 +28,7 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
def setUp(self) -> None:
self.dist = Distributor()
self.dist = Distributor(server_name="test_server")
def test_signal_dispatch(self) -> None:
self.dist.declare("alert")

View File

@ -50,7 +50,10 @@ class BatchingQueueTestCase(TestCase):
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
self.queue: BatchingQueue[str, str] = BatchingQueue(
"test_queue", hs_clock, self._process_queue
name="test_queue",
server_name="test_server",
clock=hs_clock,
process_batch_callback=self._process_queue,
)
async def _process_queue(self, values: List[str]) -> str:

View File

@ -31,7 +31,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self) -> None:
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
limiter = self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
# advance the clock a bit before making the request
self.pump(1)
@ -46,7 +53,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
limiter = self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
min_retry_interval_ms = (
self.hs.config.federation.destination_min_retry_interval_ms
@ -72,7 +86,13 @@ class RetryLimiterTestCase(HomeserverTestCase):
# now if we try again we should get a failure
self.get_failure(
get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
),
NotRetryingDestination,
)
#
@ -80,7 +100,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
#
self.pump(min_retry_interval_ms)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
limiter = self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
self.pump(1)
try:
@ -108,7 +135,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
# one more go, with success
#
self.reactor.advance(min_retry_interval_ms * retry_multiplier * 2.0)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
limiter = self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
self.pump(1)
with limiter:
@ -129,9 +163,10 @@ class RetryLimiterTestCase(HomeserverTestCase):
limiter = self.get_success(
get_retry_limiter(
"test_dest",
self.clock,
store,
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
notifier=notifier,
replication_client=replication_client,
)
@ -199,7 +234,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.hs.config.federation.destination_max_retry_interval_ms
)
self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
self.pump(1)
failure_ts = self.clock.time_msec()
@ -216,12 +258,25 @@ class RetryLimiterTestCase(HomeserverTestCase):
# Check it fails
self.get_failure(
get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
),
NotRetryingDestination,
)
# Get past retry_interval and we can try again, and still throw an error to continue the backoff
self.reactor.advance(destination_max_retry_interval_ms / 1000 + 1)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
limiter = self.get_success(
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
)
)
self.pump(1)
try:
with limiter:
@ -239,5 +294,11 @@ class RetryLimiterTestCase(HomeserverTestCase):
# Check it fails
self.get_failure(
get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
clock=self.clock,
store=store,
),
NotRetryingDestination,
)