Fix server_name in logging context for multiple Synapse instances in one process (#18868)

### Background

As part of Element's plan to support a light form of vhosting (virtual
hosting), we're currently working through the details and implications of
running multiple Synapse instances in the same Python process.

"Per-tenant logging" tracked internally by
https://github.com/element-hq/synapse-small-hosts/issues/48

### Prior art

Previously, we exposed `server_name` by providing a logging
`MetadataFilter` that injected a static `server_name` value into every record:


205d9e4fc4/synapse/config/logger.py (L216)

While this works fine for the normal case of one Synapse instance per
Python process, it configures things globally and isn't compatible with
starting multiple Synapse instances, because each subsequent tenant
overwrites the previous one.
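
To illustrate the failure mode, here's a minimal sketch (simplified from the `logger.py` wiring removed in the diff below; the hostnames are hypothetical):

```python
import logging

def install_metadata_factory(server_name: str) -> None:
    """Stamp a static `server_name` onto every record, process-wide.

    Roughly what the old `MetadataFilter` + log record factory setup did.
    """
    old_factory = logging.getLogRecordFactory()

    def factory(*args, **kwargs) -> logging.LogRecord:
        record = old_factory(*args, **kwargs)
        record.server_name = server_name
        return record

    logging.setLogRecordFactory(factory)

install_metadata_factory("a.example.com")  # tenant A starts up
install_metadata_factory("b.example.com")  # tenant B starts up in the same process

# The factories chain, but tenant B's wrapper runs last and wins: every record
# in the process now carries server_name == "b.example.com", including records
# logged by tenant A's code.
logging.warning("who am I?")  # record.server_name == "b.example.com"
```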


### What does this PR do?

We remove the `MetadataFilter` and instead track the `server_name` in the
`LoggingContext`, exposing it with our existing
[`LoggingContextFilter`](205d9e4fc4/synapse/logging/context.py (L584-L622))
that we already use to expose information about the `request`.

This means that the `server_name` value follows wherever we log, as
expected, even when we have multiple Synapse instances running in the
same process.
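
To illustrate the effect, here's a minimal sketch of the new behaviour (the hostnames are hypothetical; Synapse normally wires the filter up for you in `synapse/config/logger.py`):

```python
import logging

from synapse.logging.context import LoggingContext, LoggingContextFilter

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(server_name)s - %(message)s"))
# A single filter instance serves every tenant: it reads `server_name` off
# whichever logcontext is current at the time of each log call.
handler.addFilter(LoggingContextFilter())

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

with LoggingContext(name="tenant_a", server_name="a.example.com"):
    logger.info("hello")  # -> a.example.com - hello
with LoggingContext(name="tenant_b", server_name="b.example.com"):
    logger.info("hello")  # -> b.example.com - hello
```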


### A note on logcontext

Anywhere Synapse mistakenly uses the `sentinel` logcontext to log
something, we won't know which server the log came from. We've been fixing up
`sentinel` logcontext usage as tracked by
https://github.com/element-hq/synapse/issues/18905

Any further `sentinel` logcontext usage we find in the future can be
fixed piecemeal as normal.
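
Continuing the sketch above, a log call made outside of any logcontext runs under the `sentinel` context, so the placeholder value shows up instead:

```python
# No logcontext is active here, so `current_context()` returns the sentinel,
# whose `server_name` is a fixed placeholder value.
logger.info("hello")  # -> unknown_server_from_sentinel_context - hello
```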


d2a966f922/docs/log_contexts.md (L71-L81)


### Testing strategy

1. Adjust your logging config to include `%(server_name)s` in the format
    ```yaml
    formatters:
        precise:
            format: '%(asctime)s - %(server_name)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
    ```
1. Start Synapse: `poetry run synapse_homeserver --config-path homeserver.yaml`
1. Make some requests (`curl http://localhost:8008/_matrix/client/versions`, etc.)
1. Open the homeserver logs and notice the `server_name` in the logs as expected (see the illustrative line below). `unknown_server_from_sentinel_context` is expected for the `sentinel` logcontext (things outside of Synapse).
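
With the `precise` format above, a request log line then looks something like this (values illustrative):

```
2025-09-26 17:10:48,123 - example.com - synapse.http.server - 85 - INFO - GET-7 - ...
```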

changelog.d/18868.misc

@@ -0,0 +1 @@
Fix `server_name` in logging context for multiple Synapse instances in one process.

@@ -599,7 +599,7 @@ async def start(hs: "HomeServer") -> None:
hs.get_pusherpool().start()
def log_shutdown() -> None:
with LoggingContext("log_shutdown"):
with LoggingContext(name="log_shutdown", server_name=server_name):
logger.info("Shutting down...")
# Log when we start the shut down process.

@@ -329,7 +329,7 @@ def start(config: HomeServerConfig, args: argparse.Namespace) -> None:
# command.
async def run() -> None:
with LoggingContext(name="command"):
with LoggingContext(name="command", server_name=config.server.server_name):
await _base.start(ss)
await args.func(ss, args)
@@ -342,5 +342,5 @@ def start(config: HomeServerConfig, args: argparse.Namespace) -> None:
if __name__ == "__main__":
homeserver_config, args = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config, args)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -26,7 +26,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -386,7 +386,7 @@ def start(config: HomeServerConfig) -> None:
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -429,7 +429,7 @@ def run(hs: HomeServer) -> None:
def main() -> None:
homeserver_config = load_or_generate_config(sys.argv[1:])
with LoggingContext("main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
# check base requirements
check_requirements()
hs = setup(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -27,7 +27,7 @@ from synapse.util.logcontext import LoggingContext
def main() -> None:
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
with LoggingContext(name="main", server_name=homeserver_config.server.server_name):
start(homeserver_config)

@@ -601,7 +601,7 @@ class RootConfig:
@classmethod
def load_config_with_parser(
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv_options: List[str]
) -> Tuple[TRootConfig, argparse.Namespace]:
"""Parse the commandline and config files with the given parser
@@ -611,14 +611,14 @@
Args:
parser
argv
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
Returns:
Returns the parsed config object and the parsed argparse.Namespace
object from parser.parse_args(..)`
"""
config_args = parser.parse_args(argv)
config_args = parser.parse_args(argv_options)
config_files = find_config_files(search_paths=config_args.config_path)
obj = cls(config_files)

@@ -40,7 +40,6 @@ from twisted.logger import (
)
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
from synapse.synapse_rust import reset_logging_config
from synapse.types import JsonDict
@@ -213,13 +212,11 @@ def _setup_stdlib_logging(
# writes.
log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record = old_factory(*args, **kwargs)
log_context_filter.filter(record)
log_metadata_filter.filter(record)
return record
logging.setLogRecordFactory(factory)

@@ -159,7 +159,7 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
self._server_linearizer = Linearizer("fed_server")
self._server_linearizer = Linearizer(name="fed_server", clock=hs.get_clock())
# origins that we are currently processing a transaction from.
# a dict from origin to txn id.

@@ -98,7 +98,7 @@ class ApplicationServicesHandler:
self.is_processing = False
self._ephemeral_events_linearizer = Linearizer(
name="appservice_ephemeral_events"
name="appservice_ephemeral_events", clock=hs.get_clock()
)
def notify_interested_services(self, max_token: RoomStreamToken) -> None:

@@ -1450,8 +1450,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self.clock = hs.get_clock() # nb must be called this for @measure_func
self.device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
self._resync_linearizer = Linearizer(name="remote_device_resync")
self._remote_edu_linearizer = Linearizer(
name="remote_device_list", clock=self.clock
)
self._resync_linearizer = Linearizer(
name="remote_device_resync", clock=self.clock
)
# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[

@@ -112,8 +112,7 @@ class E2eKeysHandler:
# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
name="query_devices", max_count=10, clock=hs.get_clock()
)
self._query_appservices_for_otks = (
@@ -1765,7 +1764,9 @@ class SigningKeyEduUpdater:
assert isinstance(device_handler, DeviceWriterHandler)
self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
self._remote_edu_linearizer = Linearizer(
name="remote_signing_key", clock=self.clock
)
# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}

@@ -160,7 +160,7 @@ class FederationHandler:
self._notifier = hs.get_notifier()
self._worker_locks = hs.get_worker_locks_handler()
self._room_backfill = Linearizer("room_backfill")
self._room_backfill = Linearizer(name="room_backfill", clock=self.clock)
self._third_party_event_rules = (
hs.get_module_api_callbacks().third_party_event_rules
@@ -180,7 +180,8 @@
# When the lock is held for a given room, no other concurrent code may
# partial state or un-partial state the room.
self._is_partial_state_room_linearizer = Linearizer(
name="_is_partial_state_room_linearizer"
name="_is_partial_state_room_linearizer",
clock=self.clock,
)
# if this is the main process, fire off a background process to resume

@@ -191,7 +191,7 @@ class FederationEventHandler:
# federation event staging area.
self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_pdu_linearizer = Linearizer(name="fed_room_pdu", clock=self._clock)
async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
"""Process a PDU received via a federation /send/ transaction

@@ -513,7 +513,9 @@ class EventCreationHandler:
# We limit concurrent event creation for a room to 1. This prevents state resolution
# from occurring when sending bursts of events to a local room
self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
self.limiter = Linearizer(
max_count=1, name="room_event_creation_limit", clock=self.clock
)
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()

@@ -872,7 +872,9 @@ class PresenceHandler(BasePresenceHandler):
] = {}
self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
self.external_sync_linearizer = Linearizer(
name="external_sync_linearizer", clock=self.clock
)
if self._track_presence:
# Start a LoopingCall in 30s that fires every 5s.

@@ -36,7 +36,9 @@ class ReadMarkerHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")
self.read_marker_linearizer = Linearizer(
name="read_marker", clock=hs.get_clock()
)
async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str

@@ -114,8 +114,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if self.hs.config.server.include_profile_data_on_invite:
self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
self.member_linearizer: Linearizer = Linearizer(
name="member", clock=hs.get_clock()
)
self.member_as_limiter = Linearizer(
max_count=10, name="member_as_limiter", clock=hs.get_clock()
)
self.clock = hs.get_clock()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker

@@ -980,7 +980,10 @@ class SyncHandler:
)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(max_size=LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
cache = LruCache(
max_size=LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE,
server_name=self.server_name,
)
self.lazy_loaded_members_cache[cache_key] = cache
else:
logger.debug("found LruCache for %r", cache_key)

@@ -124,7 +124,7 @@ class MatrixFederationAgent:
# addresses, to prevent DNS rebinding.
reactor = BlocklistingReactorWrapper(reactor, ip_allowlist, ip_blocklist)
self._clock = Clock(reactor)
self._clock = Clock(reactor, server_name=server_name)
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5

@@ -107,7 +107,7 @@ class WellKnownResolver:
self.server_name = server_name
self._reactor = reactor
self._clock = Clock(reactor)
self._clock = Clock(reactor, server_name=server_name)
if well_known_cache is None:
well_known_cache = TTLCache(

@@ -481,7 +481,9 @@ class MatrixFederationHttpClient:
use_proxy=True,
)
self.remote_download_linearizer = Linearizer("remote_download_linearizer", 6)
self.remote_download_linearizer = Linearizer(
name="remote_download_linearizer", max_count=6, clock=self.clock
)
def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""

@@ -411,8 +411,19 @@ class DirectServeJsonResource(_AsyncResource):
# Clock is optional as this class is exposed to the module API.
clock: Optional[Clock] = None,
):
"""
Args:
canonical_json: TODO
extract_context: TODO
clock: This is expected to be passed in by any Synapse code.
Only optional for the Module API.
"""
if clock is None:
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
clock = Clock(
cast(ISynapseThreadlessReactor, reactor),
server_name="synapse_module_running_from_unknown_server",
)
super().__init__(clock, extract_context)
self.canonical_json = canonical_json
@@ -590,8 +601,17 @@ class DirectServeHtmlResource(_AsyncResource):
# Clock is optional as this class is exposed to the module API.
clock: Optional[Clock] = None,
):
"""
Args:
extract_context: TODO
clock: This is expected to be passed in by any Synapse code.
Only optional for the Module API.
"""
if clock is None:
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
clock = Clock(
cast(ISynapseThreadlessReactor, reactor),
server_name="synapse_module_running_from_unknown_server",
)
super().__init__(clock, extract_context)

@@ -302,10 +302,15 @@ class SynapseRequest(Request):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
# create a LogContext for this request
# Create a LogContext for this request
#
# We only care about associating logs and tallying up metrics at the per-request
# level so we don't worry about setting the `parent_context`; preventing us from
# unnecessarily piling up metrics on the main process's context.
request_id = self.get_request_id()
self.logcontext = LoggingContext(
request_id,
name=request_id,
server_name=self.our_server_name,
request=ContextRequest(
request_id=request_id,
ip_address=self.get_client_ip_if_available(),

@@ -238,12 +238,13 @@ class _Sentinel:
we should always know which server the logs are coming from.
"""
__slots__ = ["previous_context", "finished", "request", "tag"]
__slots__ = ["previous_context", "finished", "server_name", "request", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.server_name = "unknown_server_from_sentinel_context"
self.request = None
self.tag = None
@@ -282,14 +283,19 @@ class LoggingContext:
child to the parent
Args:
name: Name for the context for logging. If this is omitted, it is
inherited from the parent context.
name: Name for the context for logging.
server_name: The name of the server this context is associated with
(`config.server.server_name` or `hs.hostname`)
parent_context (LoggingContext|None): The parent of the new context
request: Synapse Request Context object. Useful to associate all the logs
happening to a given request.
"""
__slots__ = [
"previous_context",
"name",
"server_name",
"parent_context",
"_resource_usage",
"usage_start",
@@ -301,7 +307,9 @@
def __init__(
self,
name: Optional[str] = None,
*,
name: str,
server_name: str,
parent_context: "Optional[LoggingContext]" = None,
request: Optional[ContextRequest] = None,
) -> None:
@@ -314,6 +322,8 @@
# if the context is not currently active.
self.usage_start: Optional[resource.struct_rusage] = None
self.name = name
self.server_name = server_name
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
@@ -325,23 +335,15 @@
self.parent_context = parent_context
# Inherit some fields from the parent context
if self.parent_context is not None:
# we track the current request_id
# which request this corresponds to
self.request = self.parent_context.request
if request is not None:
# the request param overrides the request from the parent context
self.request = request
# if we don't have a `name`, but do have a parent context, use its name.
if self.parent_context and name is None:
name = str(self.parent_context)
if name is None:
raise ValueError(
"LoggingContext must be given either a name or a parent context"
)
self.name = name
def __str__(self) -> str:
return self.name
@@ -588,7 +590,26 @@ class LoggingContextFilter(logging.Filter):
record.
"""
def __init__(self, request: str = ""):
def __init__(
self,
# `request` is here for backwards compatibility since we previously recommended
# people manually configure `LoggingContextFilter` like the following.
#
# ```yaml
# filters:
# context:
# (): synapse.logging.context.LoggingContextFilter
# request: ""
# ```
#
# TODO: Since we now configure `LoggingContextFilter` automatically since #8051
# (2020-08-11), we could consider removing this useless parameter. This would
# require people to remove their own manual configuration of
# `LoggingContextFilter` as it would cause `TypeError: Filter.__init__() got an
# unexpected keyword argument 'request'` -> `ValueError: Unable to configure
# filter 'context'`
request: str = "",
):
self._default_request = request
def filter(self, record: logging.LogRecord) -> Literal[True]:
@@ -598,11 +619,13 @@
"""
context = current_context()
record.request = self._default_request
record.server_name = "unknown_server_from_no_context"
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
record.server_name = context.server_name
# Logging is interested in the request ID. Note that for backwards
# compatibility this is stored as the "request" on the record.
record.request = str(context)
@@ -728,12 +751,15 @@ def nested_logging_context(suffix: str) -> LoggingContext:
"Starting nested logging context from sentinel context: metrics will be lost"
)
parent_context = None
server_name = "unknown_server_from_sentinel_context"
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
server_name = parent_context.server_name
prefix = str(curr_context)
return LoggingContext(
prefix + "-" + suffix,
name=prefix + "-" + suffix,
server_name=server_name,
parent_context=parent_context,
)
@@ -1058,12 +1084,18 @@ def defer_to_threadpool(
"Calling defer_to_threadpool from sentinel context: metrics will be lost"
)
parent_context = None
server_name = "unknown_server_from_sentinel_context"
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
server_name = parent_context.server_name
def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
with LoggingContext(
name=str(curr_context),
server_name=server_name,
parent_context=parent_context,
):
return f(*args, **kwargs)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))

@@ -1,38 +0,0 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from typing import Literal
class MetadataFilter(logging.Filter):
"""Logging filter that adds constant values to each record.
Args:
metadata: Key-value pairs to add to each record.
"""
def __init__(self, metadata: dict):
self._metadata = metadata
def filter(self, record: logging.LogRecord) -> Literal[True]:
for key, value in self._metadata.items():
setattr(record, key, value)
return True

@@ -108,7 +108,7 @@ class MediaRepository:
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.thumbnail_requirements = hs.config.media.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.remote_media_linearizer = Linearizer(name="media_remote", clock=self.clock)
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals: Set[str] = set()

@@ -490,7 +490,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
"""
if instance_id is None:
instance_id = id(self)
super().__init__("%s-%s" % (name, instance_id))
super().__init__(name="%s-%s" % (name, instance_id), server_name=server_name)
self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(
desc=name, server_name=server_name, ctx=self
)

@@ -436,7 +436,9 @@ class FederationSenderHandler:
# to. This is always set before we use it.
self.federation_position: Optional[int] = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
self._fed_position_linearizer = Linearizer(
name="_fed_position_linearizer", clock=hs.get_clock()
)
async def process_replication_rows(
self, stream_name: str, token: int, rows: list

@@ -65,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
self._push_rules_handler = hs.get_push_rules_handler()
self._push_rule_linearizer = Linearizer(name="push_rules")
self._push_rule_linearizer = Linearizer(name="push_rules", clock=hs.get_clock())
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if not self._is_push_worker:

@@ -442,7 +442,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_clock(self) -> Clock:
return Clock(self._reactor)
return Clock(self._reactor, server_name=self.hostname)
def get_datastores(self) -> Databases:
if not self.datastores:

@@ -642,7 +642,9 @@ class StateResolutionHandler:
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
self.resolve_linearizer = Linearizer(
name="state_resolve_lock", clock=self.clock
)
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache: ExpiringCache[FrozenSet[int], _StateCacheEntry] = (

@@ -77,7 +77,9 @@ class StateStorageController:
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
self._joined_host_linearizer = Linearizer(
name="_JoinedHostsCache", clock=self._clock
)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)

@@ -146,7 +146,7 @@ def make_pool(
def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
with LoggingContext(name="db.on_new_connection", server_name=server_name):
engine.on_new_connection(
LoggingDatabaseConnection(
conn=conn,
@@ -1043,7 +1043,9 @@ class DatabasePool:
assert not self.engine.in_transaction(conn)
with LoggingContext(
str(curr_context), parent_context=parent_context
name=str(curr_context),
server_name=self.server_name,
parent_context=parent_context,
) as context:
with opentracing.start_active_span(
operation_name="db.connection",

@@ -47,7 +47,6 @@ from typing import (
Tuple,
TypeVar,
Union,
cast,
overload,
)
@@ -65,7 +64,6 @@ from synapse.logging.context import (
run_coroutine_in_background,
run_in_background,
)
from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)
@@ -551,25 +549,20 @@ class Linearizer:
def __init__(
self,
name: Optional[str] = None,
*,
name: str,
max_count: int = 1,
clock: Optional[Clock] = None,
clock: Clock,
):
"""
Args:
name: TODO
max_count: The maximum number of concurrent accesses
clock: (ideally, the homeserver clock `hs.get_clock()`)
"""
if name is None:
self.name: Union[str, int] = id(self)
else:
self.name = name
if not clock:
from twisted.internet import reactor
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
self._clock = clock
self.name = name
self.max_count = max_count
self._clock = clock
# key_to_defer is a map from the key to a _LinearizerEntry.
self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}

@@ -420,7 +420,7 @@ class LruCache(Generic[KT, VT]):
self,
*,
max_size: int,
server_name: Literal[None] = None,
server_name: str,
cache_name: Literal[None] = None,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable[[VT], int]] = None,
@@ -435,7 +435,7 @@
self,
*,
max_size: int,
server_name: Optional[str] = None,
server_name: str,
cache_name: Optional[str] = None,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable[[VT], int]] = None,
@@ -450,12 +450,10 @@
max_size: The maximum amount of entries the cache can hold
server_name: The homeserver name that this cache is associated with
(used to label the metric) (`hs.hostname`). Must be set if `cache_name` is
set. If unset, no metrics will be reported on this cache.
(used to label the metric) (`hs.hostname`).
cache_name: The name of this cache, for the prometheus metrics. Must be set
if `server_name` is set. If unset, no metrics will be reported on this
cache.
cache_name: The name of this cache, for the prometheus metrics. If unset, no
metrics will be reported on this cache.
cache_type:
type of underlying cache to be used. Typically one of dict
@@ -497,7 +495,9 @@
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
if clock is None:
real_clock = Clock(cast(ISynapseThreadlessReactor, reactor))
real_clock = Clock(
cast(ISynapseThreadlessReactor, reactor), server_name=server_name
)
else:
real_clock = clock

@@ -44,6 +44,7 @@ class Clock:
"""
_reactor: ISynapseThreadlessReactor = attr.ib()
_server_name: str = attr.ib()
async def sleep(self, seconds: float) -> None:
d: defer.Deferred[float] = defer.Deferred()
@@ -144,7 +145,11 @@ class Clock:
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(context.LoggingContext("looping_call")):
with context.PreserveLoggingContext(
context.LoggingContext(
name="looping_call", server_name=self._server_name
)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
@@ -199,7 +204,9 @@ class Clock:
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(context.LoggingContext("call_later")):
with context.PreserveLoggingContext(
context.LoggingContext(name="call_later", server_name=self._server_name)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
@@ -258,7 +265,9 @@ class Clock:
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(
context.LoggingContext("call_when_running")
context.LoggingContext(
name="call_when_running", server_name=self._server_name
)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
@@ -313,7 +322,11 @@ class Clock:
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(context.LoggingContext("system_event")):
with context.PreserveLoggingContext(
context.LoggingContext(
name="system_event", server_name=self._server_name
)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor

@@ -32,6 +32,7 @@ from typing import NoReturn, Optional, Type
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
)
@@ -149,9 +150,12 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
signal.signal(signal.SIGTERM, sigterm)
# Copy the `server_name` from the current logcontext
server_name = current_context().server_name
# Cleanup pid file at exit.
def exit() -> None:
with LoggingContext("atexit"):
with LoggingContext(name="atexit", server_name=server_name):
logger.warning("Stopping daemon.")
os.remove(pid_file)
sys.exit(0)

@@ -217,7 +217,11 @@ class Measure:
else:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
self._logging_context = LoggingContext(
name=str(curr_context),
server_name=self.server_name,
parent_context=parent_context,
)
self.start: Optional[float] = None
def __enter__(self) -> "Measure":

@@ -86,7 +86,7 @@ async def main(reactor: ISynapseReactor, loops: int) -> float:
hs_config = Config()
# To be able to sleep.
clock = Clock(reactor)
clock = Clock(reactor, server_name=hs_config.server.server_name)
errors = StringIO()
publisher = LogPublisher()

@@ -29,7 +29,9 @@ async def main(reactor: ISynapseReactor, loops: int) -> float:
"""
Benchmark `loops` number of insertions into LruCache without eviction.
"""
cache: LruCache[int, bool] = LruCache(max_size=loops)
cache: LruCache[int, bool] = LruCache(
max_size=loops, server_name="synmark_benchmark"
)
start = perf_counter()

@@ -30,7 +30,9 @@ async def main(reactor: ISynapseReactor, loops: int) -> float:
Benchmark `loops` number of insertions into LruCache where half of them are
evicted.
"""
cache: LruCache[int, bool] = LruCache(max_size=loops // 2)
cache: LruCache[int, bool] = LruCache(
max_size=loops // 2, server_name="synmark_benchmark"
)
start = perf_counter()

@@ -75,7 +75,7 @@ class CacheConfigTests(TestCase):
the default cache size in the interim, and then resized once the config
is loaded.
"""
cache: LruCache = LruCache(max_size=100)
cache: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
@@ -96,7 +96,7 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache: LruCache = LruCache(max_size=100)
cache: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200)
@@ -106,7 +106,7 @@ class CacheConfigTests(TestCase):
the default cache size in the interim, and then resized to the new
default cache size once the config is loaded.
"""
cache: LruCache = LruCache(max_size=100)
cache: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
@@ -126,7 +126,7 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache: LruCache = LruCache(max_size=100)
cache: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150)
@@ -145,15 +145,15 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache_a: LruCache = LruCache(max_size=100)
cache_a: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
self.assertEqual(cache_a.max_size, 200)
cache_b: LruCache = LruCache(max_size=100)
cache_b: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
self.assertEqual(cache_b.max_size, 300)
cache_c: LruCache = LruCache(max_size=100)
cache_c: LruCache = LruCache(max_size=100, server_name="test_server")
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200)
@@ -169,6 +169,7 @@ class CacheConfigTests(TestCase):
cache: LruCache = LruCache(
max_size=self.config.event_cache_size,
apply_cache_factor_from_config=False,
server_name="test_server",
)
add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)

@@ -121,7 +121,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
async def first_lookup() -> None:
with LoggingContext(
"context_11", request=cast(ContextRequest, FakeRequest("context_11"))
name="context_11",
server_name=self.hs.hostname,
request=cast(ContextRequest, FakeRequest("context_11")),
):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0), ("server11", {}, 0)]
@@ -161,7 +163,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
async def second_lookup() -> None:
with LoggingContext(
"context_12", request=cast(ContextRequest, FakeRequest("context_12"))
name="context_12",
server_name=self.hs.hostname,
request=cast(ContextRequest, FakeRequest("context_12")),
):
res_deferreds_2 = kr.verify_json_objects_for_server(
[

@@ -229,7 +229,10 @@ class MessageAcceptTests(unittest.FederatingHomeserverTestCase):
room_version=RoomVersions.V10,
)
with LoggingContext("test-context"):
with LoggingContext(
name="test-context",
server_name=self.hs.hostname,
):
failure = self.get_failure(
self.federation_event_handler.on_receive_pdu(
self.OTHER_SERVER_NAME, lying_event

@@ -318,6 +318,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
current_depth = 1
limit = 100
# Make sure backfill still works
self.get_success(
self.hs.get_federation_handler().maybe_backfill(
@@ -485,6 +486,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
self.get_success(
self.hs.get_federation_event_handler().on_send_membership_event(
other_server, join_event

@@ -224,7 +224,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.reactor.advance(60000)
# Finally, the call under test: send the pulled event into _process_pulled_event
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=False
@@ -321,7 +324,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# The function under test: try to process the pulled event
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=True
@@ -339,7 +345,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(backfill_num_attempts, 1)
# The function under test: try to process the pulled event again
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=True
@@ -447,7 +456,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(backfill_num_attempts, 1)
# The function under test: try to process the pulled event
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=True
@@ -602,7 +614,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# The function under test: try to backfill and process the pulled event
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler().backfill(
self.OTHER_SERVER_NAME,
@@ -742,7 +757,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
}
# The function under test: try to backfill and process the pulled event
with LoggingContext("test"):
with LoggingContext(
name="test",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler().backfill(
self.OTHER_SERVER_NAME,
@@ -887,7 +905,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
next_depth += 1
next_timestamp += 100
with LoggingContext("send_rejected_power_levels_event"):
with LoggingContext(
name="send_rejected_power_levels_event",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME,
@@ -969,7 +990,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
"during state resolution. The test setup is incorrect.",
)
with LoggingContext("send_rejected_kick_event"):
with LoggingContext(
name="send_rejected_kick_event",
server_name=self.hs.hostname,
):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False
@@ -1085,7 +1109,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# We have to bump the clock a bit, to keep the retry logic in
# `FederationClient.get_pdu` happy
self.reactor.advance(60000)
with LoggingContext("send_pulled_event"):
with LoggingContext(
name="send_pulled_event",
server_name=self.hs.hostname,
):
async def get_event(
destination: str, event_id: str, timeout: Optional[int] = None

@@ -200,7 +200,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
with LoggingContext("one") as context:
with LoggingContext(
name="one",
server_name="test_server",
) as context:
fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri)
# Nothing happened yet

@@ -52,7 +52,10 @@ class SrvResolverTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
with LoggingContext("one") as ctx:
with LoggingContext(
name="one",
server_name="test_server",
) as ctx:
resolve_d = resolver.resolve_service(service_name)
result: List[Server]
result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment]

@@ -502,7 +502,7 @@ def _log_for_request(request_number: int, message: str) -> None:
"""Logs a message for an iteration of `make_request_with_cancellation_test`."""
# We want consistent alignment when logging stack traces, so ensure the logging
# context has a fixed width name.
with LoggingContext(name=f"request-{request_number:<2}"):
with LoggingContext(name=f"request-{request_number:<2}", server_name="test_server"):
logger.info(message)

@@ -80,7 +80,10 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request() -> Generator["Deferred[Any]", object, object]:
with LoggingContext("one") as context:
with LoggingContext(
name="one",
server_name=self.hs.hostname,
) as context:
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
)

@@ -91,7 +91,7 @@ class TracingScopeTestCase(TestCase):
def test_start_active_span(self) -> None:
# the scope manager assumes a logging context of some sort.
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
self.assertIsNone(self._tracer.active_span)
# start_active_span should start and activate a span.
@@ -115,7 +115,7 @@ class TracingScopeTestCase(TestCase):
def test_nested_spans(self) -> None:
"""Starting two spans off inside each other should work"""
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
@@ -164,7 +164,8 @@ class TracingScopeTestCase(TestCase):
# Reactor/Clock interfaces), via inheritance from
# `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock`
clock = Clock(
reactor # type: ignore[arg-type]
reactor, # type: ignore[arg-type]
server_name="test_server",
)
scopes = []
@@ -200,7 +201,7 @@ class TracingScopeTestCase(TestCase):
self.assertEqual(self._tracer.active_span, root_scope.span)
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
# start the test off
d1 = defer.ensureDeferred(root())
@@ -234,7 +235,8 @@ class TracingScopeTestCase(TestCase):
# Reactor/Clock interfaces), via inheritance from
# `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock`
clock = Clock(
reactor # type: ignore[arg-type]
reactor, # type: ignore[arg-type]
server_name="test_server",
)
scope_map: Dict[str, opentracing.Scope] = {}
@@ -314,7 +316,7 @@ class TracingScopeTestCase(TestCase):
# We shouldn't see any active spans outside of the scope
self.assertIsNone(self._tracer.active_span)
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
# Start the test off
d_root = defer.ensureDeferred(root())
@@ -357,7 +359,7 @@ class TracingScopeTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with sync functions
"""
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
@trace_with_opname("fixture_sync_func", tracer=self._tracer)
@tag_args
@@ -378,7 +380,7 @@ class TracingScopeTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return deferreds
"""
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@tag_args
@@ -402,7 +404,7 @@ class TracingScopeTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with async functions
"""
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
@trace_with_opname("fixture_async_func", tracer=self._tracer)
@tag_args
@@ -424,7 +426,7 @@ class TracingScopeTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return an awaitable (e.g. a coroutine)
"""
with LoggingContext("root context"):
with LoggingContext(name="root context", server_name="test_server"):
# Something we can return without `await` to get a coroutine
async def fixture_async_func() -> str:
return "foo"

@@ -63,13 +63,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
expected_log_keys = {
"log",
"time",
"level",
"namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
self.assertEqual(log["log"], "Hello there, wally!")
def test_extra_data(self) -> None:
@@ -87,7 +87,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
expected_log_keys = {
"log",
"time",
"level",
@@ -96,8 +96,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"foo",
"int",
"bool",
]
self.assertCountEqual(log.keys(), expected_log_keys)
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
# Check the values of the extra fields.
self.assertEqual(log["foo"], "bar")
@@ -117,12 +117,12 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
expected_log_keys = {
"log",
"level",
"namespace",
]
self.assertCountEqual(log.keys(), expected_log_keys)
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
self.assertEqual(log["log"], "Hello there, wally!")
def test_with_context(self) -> None:
@@ -134,19 +134,20 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
with LoggingContext("name"):
with LoggingContext(name="name", server_name="test_server"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
expected_log_keys = {
"log",
"level",
"namespace",
"request",
]
self.assertCountEqual(log.keys(), expected_log_keys)
"server_name",
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "name")
@@ -187,14 +188,16 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
request.requester = "@foo:test"
with LoggingContext(
request.get_request_id(), parent_context=request.logcontext
name=request.get_request_id(),
server_name="test_server",
parent_context=request.logcontext,
):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
# The terse logger includes additional request information, if possible.
expected_log_keys = [
expected_log_keys = {
"log",
"level",
"namespace",
@@ -207,8 +210,9 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"url",
"protocol",
"user_agent",
]
self.assertCountEqual(log.keys(), expected_log_keys)
"server_name",
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertTrue(log["request"].startswith("POST-"))
self.assertEqual(log["ip_address"], "127.0.0.1")
@@ -236,14 +240,14 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
expected_log_keys = {
"log",
"level",
"namespace",
"exc_type",
"exc_value",
]
self.assertCountEqual(log.keys(), expected_log_keys)
}
self.assertIncludes(log.keys(), expected_log_keys, exact=True)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["exc_type"], "ValueError")
self.assertEqual(log["exc_value"], "That's wrong, you wally!")

@@ -90,12 +90,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
@defer.inlineCallbacks
def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
yield defer.ensureDeferred(Clock(reactor).sleep(0))
yield defer.ensureDeferred(
Clock(reactor, server_name="test_server").sleep(0)
)
return 1, {}
@defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
with LoggingContext(name="c", server_name="test_server") as c1:
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
@@ -125,7 +127,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
called[0] = True
raise Exception("boo")
with LoggingContext("test") as test_context:
with LoggingContext(name="test", server_name="test_server") as test_context:
try:
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
@@ -157,7 +159,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
called[0] = True
return defer.fail(Exception("boo"))
with LoggingContext("test") as test_context:
with LoggingContext(name="test", server_name="test_server") as test_context:
try:
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb

@@ -787,7 +787,7 @@ class ThreadPool:
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
hs_clock = Clock(clock, server_name="test_server")
return clock, hs_clock

@@ -76,7 +76,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.event_ids.append(event.event_id)
def test_simple(self) -> None:
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
res = self.get_success(
self.store.have_seen_events(
self.room_id, [self.event_ids[0], "eventdoesnotexist"]
@@ -88,7 +88,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
# a second lookup of the same events should cause no queries
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
res = self.get_success(
self.store.have_seen_events(
self.room_id, [self.event_ids[0], "eventdoesnotexist"]
@@ -113,7 +113,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
)
)
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
# First, check `have_seen_event` for an event we have not seen yet
# to prime the cache with a `false` value.
res = self.get_success(
@@ -135,7 +135,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
)
)
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
# Check `have_seen_event` again and we should see the updated fact
# that we have now seen the event after persisting it.
res = self.get_success(
@@ -166,7 +166,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
res = self.store._get_event_cache.get_local((event.event_id,))
self.assertEqual(res, None, "Event was cached when it should not have been.")
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
# Persist the event which should invalidate then prefill the
# `_get_event_cache` so we don't return stale values.
# Side Note: Apparently, persisting an event isn't a transaction in the
@@ -200,7 +200,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
"""
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
# Prime the cache with some values
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
@@ -213,7 +213,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Clear the cache with any events associated with the `room_id`
self.store.have_seen_event.invalidate((self.room_id,))
with LoggingContext(name="test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
)
@@ -249,7 +249,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB."""
with LoggingContext("test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
self.get_success(self.store.get_event(self.event_id))
# We should have fetched the event from the DB
@@ -263,7 +263,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache
self.store._get_event_cache.clear()
with LoggingContext("test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
# We keep hold of the event event though we never use it.
event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841
@@ -273,7 +273,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache
self.store._get_event_cache.clear()
with LoggingContext("test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
self.get_success(self.store.get_event(self.event_id))
# Since the event is still in memory we shouldn't have fetched it
@@ -285,7 +285,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
out once.
"""
with LoggingContext("test") as ctx:
with LoggingContext(name="test", server_name=self.hs.hostname) as ctx:
d = yieldable_gather_results(
self.store.get_event, [self.event_id, self.event_id]
)
@@ -531,8 +531,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
"runWithConnection",
new=runWithConnection,
):
ctx1 = LoggingContext("get_event1")
ctx2 = LoggingContext("get_event2")
ctx1 = LoggingContext(name="get_event1", server_name=self.hs.hostname)
ctx2 = LoggingContext(name="get_event2", server_name=self.hs.hostname)
async def get_event(ctx: LoggingContext) -> None:
with ctx:

@@ -72,15 +72,10 @@ class LockTestCase(unittest.HomeserverTestCase):
release_lock.callback(None)
# Run the tasks to completion.
# To work around `Linearizer`s using a different reactor to sleep when
# contended (https://github.com/matrix-org/synapse/issues/12841), we call
# `runUntilCurrent` on `twisted.internet.reactor`, which is a different
# reactor to that used by the homeserver.
assert isinstance(reactor, ReactorBase)
self.get_success(task1)
reactor.runUntilCurrent()
self.pump()
self.get_success(task2)
reactor.runUntilCurrent()
self.pump()
self.get_success(task3)
# At most one task should have held the lock at a time.
@@ -223,15 +218,11 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
release_lock.callback(None)
# Run the tasks to completion.
# To work around `Linearizer`s using a different reactor to sleep when
# contended (https://github.com/matrix-org/synapse/issues/12841), we call
# `runUntilCurrent` on `twisted.internet.reactor`, which is a different
# reactor to that used by the homeserver.
assert isinstance(reactor, ReactorBase)
self.get_success(task1)
reactor.runUntilCurrent()
self.pump()
self.get_success(task2)
reactor.runUntilCurrent()
self.pump()
self.get_success(task3)
# At most one task should have held the lock at a time.
@ -275,15 +266,11 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
release_lock.callback(None)
# Run the tasks to completion.
# To work around `Linearizer`s using a different reactor to sleep when
# contended (https://github.com/matrix-org/synapse/issues/12841), we call
# `runUntilCurrent` on `twisted.internet.reactor`, which is a different
# reactor to that used by the homeserver.
assert isinstance(reactor, ReactorBase)
self.get_success(task1)
reactor.runUntilCurrent()
self.pump()
self.get_success(task2)
reactor.runUntilCurrent()
self.pump()
self.get_success(task3)
# At most one task should have held the lock at a time.
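
The deleted workaround above was only needed because `Linearizer` used to sleep on the global reactor. A self-contained sketch of why pumping the test's own fake reactor now suffices, using Twisted's in-memory clock (independent of Synapse's test helpers):

```python
from twisted.internet.task import Clock

fake_reactor = Clock()
fired = []
fake_reactor.callLater(0, lambda: fired.append("ran"))

# Advancing (pumping) the fake reactor runs the delayed calls scheduled on
# it (the same reactor the homeserver's clock now uses), so no separate
# `runUntilCurrent()` on the global reactor is needed.
fake_reactor.advance(0)
assert fired == ["ran"]
```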

View File

@ -634,7 +634,7 @@ class HomeserverTestCase(TestCase):
)
def setup_test_homeserver(
self, name: Optional[str] = None, **kwargs: Any
self, server_name: Optional[str] = None, **kwargs: Any
) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
@ -656,8 +656,8 @@ class HomeserverTestCase(TestCase):
# The server name can be specified using either the `server_name` argument or a
# config override. The `server_name` argument takes precedence over any config
# overrides.
if name is not None:
config["server_name"] = name
if server_name is not None:
config["server_name"] = server_name
# Parse the config from a config dict into a HomeServerConfig
config_obj = make_homeserver_config_obj(config)
@ -666,10 +666,11 @@ class HomeserverTestCase(TestCase):
# The server name in the config is now the `server_name` argument, if provided,
# or the `server_name` from a config override, or the default of "test".
# Whichever it is, we construct a homeserver with a matching name.
kwargs["name"] = config_obj.server.server_name
server_name = config_obj.server.server_name
kwargs["name"] = server_name
async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
with LoggingContext(name="run_bg_updates", server_name=server_name):
self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, **kwargs)
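
A hypothetical test showing the renamed parameter in use (the keyword is now `server_name` rather than `name`):

```python
from tests.unittest import HomeserverTestCase


class ExampleTestCase(HomeserverTestCase):
    def make_homeserver(self, reactor, clock):
        # `server_name` replaces the old `name` argument and still takes
        # precedence over a `server_name` set via config overrides.
        return self.setup_test_homeserver(server_name="red")
```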

View File

@ -306,7 +306,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup() -> Generator["Deferred[Any]", object, int]:
with LoggingContext("c1") as c1:
with LoggingContext(name="c1", server_name="test_server") as c1:
r = yield obj.fn(1)
self.assertEqual(current_context(), c1)
return cast(int, r)
@ -350,7 +350,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup() -> Generator["Deferred[object]", object, None]:
with LoggingContext("c1") as c1:
with LoggingContext(name="c1", server_name="test_server") as c1:
try:
d = obj.fn(1)
self.assertEqual(
@ -547,7 +547,7 @@ class DescriptorTestCase(unittest.TestCase):
obj = Cls()
async def do_lookup() -> None:
with LoggingContext("c1") as c1:
with LoggingContext(name="c1", server_name="test_server") as c1:
try:
await obj.fn(123)
self.fail("No CancelledError thrown")
@ -843,7 +843,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
assert context.name == "c1"
return self.mock(args1, arg2)
with LoggingContext("c1") as c1:
with LoggingContext(name="c1", server_name="test_server") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
@ -1025,7 +1025,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
async def do_lookup() -> None:
with LoggingContext("c1") as c1:
with LoggingContext(name="c1", server_name="test_server") as c1:
try:
await obj.list_fn([123])
self.fail("No CancelledError thrown")

View File

@ -223,7 +223,7 @@ class TimeoutDeferredTest(TestCase):
incomplete_d: Deferred = Deferred()
incomplete_d.addErrback(mark_was_cancelled)
with LoggingContext("one") as context_one:
with LoggingContext(name="one", server_name="test_server") as context_one:
timing_out_d = timeout_deferred(
deferred=incomplete_d,
timeout=1.0,
@ -536,7 +536,7 @@ class DelayCancellationTests(TestCase):
await make_deferred_yieldable(blocking_d)
async def outer() -> None:
with LoggingContext("c") as c:
with LoggingContext(name="c", server_name="test_server") as c:
try:
await delay_cancellation(inner())
self.fail("`CancelledError` was not raised")
@ -651,7 +651,7 @@ class GatherCoroutineTests(TestCase):
def test_single(self) -> None:
"Test passing in a single coroutine works"
with LoggingContext("test_ctx") as text_ctx:
with LoggingContext(name="test_ctx", server_name="test_server") as text_ctx:
deferred: "defer.Deferred[None]"
coroutine, deferred = self.make_coroutine()
@ -677,7 +677,7 @@ class GatherCoroutineTests(TestCase):
def test_multiple_resolve(self) -> None:
"Test passing in multiple coroutine that all resolve works"
with LoggingContext("test_ctx") as test_ctx:
with LoggingContext(name="test_ctx", server_name="test_server") as test_ctx:
deferred1: "defer.Deferred[int]"
coroutine1, deferred1 = self.make_coroutine()
deferred2: "defer.Deferred[str]"
@ -710,7 +710,7 @@ class GatherCoroutineTests(TestCase):
def test_multiple_fail(self) -> None:
"Test passing in multiple coroutine where one fails does the right thing"
with LoggingContext("test_ctx") as test_ctx:
with LoggingContext(name="test_ctx", server_name="test_server") as test_ctx:
deferred1: "defer.Deferred[int]"
coroutine1, deferred1 = self.make_coroutine()
deferred2: "defer.Deferred[str]"

View File

@ -21,14 +21,16 @@
from typing import Hashable, Protocol, Tuple
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from synapse.logging.context import LoggingContext, current_context
from synapse.util.async_helpers import Linearizer
from tests import unittest
from tests.server import (
get_clock,
)
class UnblockFunction(Protocol):
@ -36,6 +38,9 @@ class UnblockFunction(Protocol):
class LinearizerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.reactor, self.clock = get_clock()
def _start_task(
self, linearizer: Linearizer, key: Hashable
) -> Tuple["Deferred[None]", "Deferred[None]", UnblockFunction]:
@ -73,13 +78,12 @@ class LinearizerTestCase(unittest.TestCase):
def _pump(self) -> None:
"""Pump the reactor to advance `Linearizer`s."""
assert isinstance(reactor, ReactorBase)
while reactor.getDelayedCalls():
reactor.runUntilCurrent()
while self.reactor.getDelayedCalls():
self.reactor.pump([0] * 100)
def test_linearizer(self) -> None:
"""Tests that a task is queued up behind an earlier task."""
linearizer = Linearizer()
linearizer = Linearizer(name="test_linearizer", clock=self.clock)
key = object()
@ -100,7 +104,7 @@ class LinearizerTestCase(unittest.TestCase):
Runs through the same scenario as `test_linearizer`.
"""
linearizer = Linearizer()
linearizer = Linearizer(name="test_linearizer", clock=self.clock)
key = object()
@ -131,11 +135,11 @@ class LinearizerTestCase(unittest.TestCase):
The stack should *not* explode when the slow thing completes.
"""
linearizer = Linearizer()
linearizer = Linearizer(name="test_linearizer", clock=self.clock)
key = ""
async def func(i: int) -> None:
with LoggingContext("func(%s)" % i) as lc:
with LoggingContext(name="func(%s)" % i, server_name="test_server") as lc:
async with linearizer.queue(key):
self.assertEqual(current_context(), lc)
@ -151,24 +155,24 @@ class LinearizerTestCase(unittest.TestCase):
def test_multiple_entries(self) -> None:
"""Tests a `Linearizer` with a concurrency above 1."""
limiter = Linearizer(max_count=3)
linearizer = Linearizer(name="test_linearizer", max_count=3, clock=self.clock)
key = object()
_, acquired_d1, unblock1 = self._start_task(limiter, key)
_, acquired_d1, unblock1 = self._start_task(linearizer, key)
self.assertTrue(acquired_d1.called)
_, acquired_d2, unblock2 = self._start_task(limiter, key)
_, acquired_d2, unblock2 = self._start_task(linearizer, key)
self.assertTrue(acquired_d2.called)
_, acquired_d3, unblock3 = self._start_task(limiter, key)
_, acquired_d3, unblock3 = self._start_task(linearizer, key)
self.assertTrue(acquired_d3.called)
# These next two tasks have to wait.
_, acquired_d4, unblock4 = self._start_task(limiter, key)
_, acquired_d4, unblock4 = self._start_task(linearizer, key)
self.assertFalse(acquired_d4.called)
_, acquired_d5, unblock5 = self._start_task(limiter, key)
_, acquired_d5, unblock5 = self._start_task(linearizer, key)
self.assertFalse(acquired_d5.called)
# Once the first task completes, the fourth task can continue.
@ -186,13 +190,13 @@ class LinearizerTestCase(unittest.TestCase):
unblock5()
# The next task shouldn't have to wait.
_, acquired_d6, unblock6 = self._start_task(limiter, key)
_, acquired_d6, unblock6 = self._start_task(linearizer, key)
self.assertTrue(acquired_d6.called)
unblock6()
def test_cancellation(self) -> None:
"""Tests cancellation while waiting for a `Linearizer`."""
linearizer = Linearizer()
linearizer = Linearizer(name="test_linearizer", clock=self.clock)
key = object()
@ -226,7 +230,7 @@ class LinearizerTestCase(unittest.TestCase):
def test_cancellation_during_sleep(self) -> None:
"""Tests cancellation during the sleep just after waiting for a `Linearizer`."""
linearizer = Linearizer()
linearizer = Linearizer(name="test_linearizer", clock=self.clock)
key = object()
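
As these hunks show, a `Linearizer` is now constructed with an explicit `name` and `clock`. A minimal sketch using the same `get_clock()` helper imported at the top of this file:

```python
from synapse.util.async_helpers import Linearizer
from tests.server import get_clock

reactor, clock = get_clock()
linearizer = Linearizer(name="example", clock=clock)


async def guarded(key: str) -> None:
    # At most `max_count` (default 1) tasks per key run the critical
    # section at once; any others queue behind them.
    async with linearizer.queue(key):
        ...
```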

View File

@ -58,7 +58,7 @@ class LoggingContextTestCase(unittest.TestCase):
@logcontext_clean
def test_with_context(self) -> None:
with LoggingContext("test"):
with LoggingContext(name="test", server_name="test_server"):
self._check_test_key("test")
@logcontext_clean
@ -66,7 +66,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.sleep`
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -80,7 +80,7 @@ class LoggingContextTestCase(unittest.TestCase):
# other words, another task shouldn't have leaked their context to us.
self._check_test_key("sentinel")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -92,7 +92,7 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback()))
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("foo")
await clock.sleep(0)
@ -111,7 +111,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.looping_call`
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -125,7 +125,7 @@ class LoggingContextTestCase(unittest.TestCase):
# which server spawned this loop and which server the logs came from.
self._check_test_key("looping_call")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -135,7 +135,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
lc = clock.looping_call(
lambda: defer.ensureDeferred(competing_callback()), 0
)
@ -161,7 +161,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.looping_call_now`
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -175,7 +175,7 @@ class LoggingContextTestCase(unittest.TestCase):
# which server spawned this loop and which server the logs came from.
self._check_test_key("looping_call")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -185,7 +185,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
lc = clock.looping_call_now(
lambda: defer.ensureDeferred(competing_callback()), 0
)
@ -209,7 +209,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.call_later`
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -223,7 +223,7 @@ class LoggingContextTestCase(unittest.TestCase):
# which server spawned this loop and which server the logs came from.
self._check_test_key("call_later")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -233,7 +233,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
clock.call_later(0, lambda: defer.ensureDeferred(competing_callback()))
self._check_test_key("foo")
await clock.sleep(0)
@ -261,7 +261,7 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -274,7 +274,7 @@ class LoggingContextTestCase(unittest.TestCase):
# The deferred callback should have the same logcontext as the caller
self._check_test_key("foo")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -284,7 +284,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
d: defer.Deferred[None] = defer.Deferred()
d.addCallback(lambda _: defer.ensureDeferred(competing_callback()))
self._check_test_key("foo")
@ -318,7 +318,7 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -331,7 +331,7 @@ class LoggingContextTestCase(unittest.TestCase):
# The deferred callback should have the same logcontext as the caller
self._check_test_key("sentinel")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -341,7 +341,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
d: defer.Deferred[None] = defer.Deferred()
d.addCallback(lambda _: defer.ensureDeferred(competing_callback()))
self._check_test_key("foo")
@ -379,7 +379,7 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -392,7 +392,7 @@ class LoggingContextTestCase(unittest.TestCase):
# The deferred callback should have the same logcontext as the caller
self._check_test_key("foo")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -409,7 +409,9 @@ class LoggingContextTestCase(unittest.TestCase):
# context manager lifetime methods of `LoggingContext` (`__enter__`/`__exit__`).
# And we can still set the current logcontext by using `PreserveLoggingContext`
# and passing in the "foo" logcontext.
with PreserveLoggingContext(LoggingContext("foo")):
with PreserveLoggingContext(
LoggingContext(name="foo", server_name="test_server")
):
d: defer.Deferred[None] = defer.Deferred()
d.addCallback(lambda _: defer.ensureDeferred(competing_callback()))
self._check_test_key("foo")
@ -448,14 +450,14 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("sentinel")
async def _test_run_in_background(self, function: Callable[[], object]) -> None:
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
callback_finished = False
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
# Fire off the function, but don't wait on it.
deferred = run_in_background(function)
self._check_test_key("foo")
@ -490,7 +492,7 @@ class LoggingContextTestCase(unittest.TestCase):
@logcontext_clean
async def test_run_in_background_with_blocking_fn(self) -> None:
async def blocking_function() -> None:
await Clock(reactor).sleep(0)
await Clock(reactor, server_name="test_server").sleep(0)
await self._test_run_in_background(blocking_function)
@ -523,7 +525,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc() -> None:
self._check_test_key("foo")
d = defer.ensureDeferred(Clock(reactor).sleep(0))
d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0))
self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("foo")
@ -552,7 +554,7 @@ class LoggingContextTestCase(unittest.TestCase):
This will stress the logic around incomplete deferreds in `run_coroutine_in_background`.
"""
clock = Clock(reactor)
clock = Clock(reactor, server_name="test_server")
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@ -565,7 +567,7 @@ class LoggingContextTestCase(unittest.TestCase):
# The callback should have the same logcontext as the caller
self._check_test_key("foo")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
await clock.sleep(0)
self._check_test_key("competing")
@ -575,7 +577,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
run_coroutine_in_background(competing_callback())
self._check_test_key("foo")
await clock.sleep(0)
@ -608,7 +610,7 @@ class LoggingContextTestCase(unittest.TestCase):
# The callback should have the same logcontext as the caller
self._check_test_key("foo")
with LoggingContext("competing"):
with LoggingContext(name="competing", server_name="test_server"):
# We `await` here, but there is nothing to wait for since the
# deferred is already complete, so we should immediately continue
# executing in the same context.
@ -622,7 +624,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
run_coroutine_in_background(competing_callback())
self._check_test_key("foo")
@ -648,7 +650,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@ -665,7 +667,7 @@ class LoggingContextTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[object]", object, None]:
sentinel_context = current_context()
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@ -677,7 +679,7 @@ class LoggingContextTestCase(unittest.TestCase):
@logcontext_clean
def test_nested_logging_context(self) -> None:
with LoggingContext("foo"):
with LoggingContext(name="foo", server_name="test_server"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")
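
For reference, the `Clock` construction these hunks converge on; `"test_server"` is the placeholder name used throughout this file:

```python
from twisted.internet import reactor

from synapse.util import Clock

# Passing `server_name` means sleeps and timers started from this clock
# (`sleep`, `call_later`, `looping_call`, ...) can be attributed to the
# right homeserver in the logs.
clock = Clock(reactor, server_name="test_server")
```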

View File

@ -34,13 +34,13 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self) -> None:
cache: LruCache[str, str] = LruCache(max_size=1)
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None:
cache: LruCache[int, int] = LruCache(max_size=2)
cache: LruCache[int, int] = LruCache(max_size=2, server_name="test_server")
cache[1] = 1
cache[2] = 2
@ -54,7 +54,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(3), 3)
def test_setdefault(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1)
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@ -63,7 +63,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key"), 2)
def test_pop(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1)
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
@ -71,7 +71,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
def test_del_multi(self) -> None:
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(
max_size=4, cache_type=TreeCache
max_size=4, cache_type=TreeCache, server_name="test_server"
)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
@ -91,7 +91,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
# Man from del_multi say "Yes".
def test_clear(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1)
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@ -107,7 +107,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1)
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache.set("key", "value")
self.assertFalse(m.called)
@ -126,7 +126,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_multi_get(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1)
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache.set("key", "value")
self.assertFalse(m.called)
@ -145,7 +145,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_set(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1)
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@ -161,7 +161,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_pop(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1)
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@ -182,7 +182,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m4 = Mock()
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(
max_size=4, cache_type=TreeCache
max_size=4, cache_type=TreeCache, server_name="test_server"
)
cache.set(("a", "1"), "value", callbacks=[m1])
@ -205,7 +205,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_clear(self) -> None:
m1 = Mock()
m2 = Mock()
cache: LruCache[str, str] = LruCache(max_size=5)
cache: LruCache[str, str] = LruCache(max_size=5, server_name="test_server")
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@ -222,7 +222,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache: LruCache[str, str] = LruCache(max_size=2)
cache: LruCache[str, str] = LruCache(max_size=2, server_name="test_server")
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@ -258,7 +258,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self) -> None:
cache: LruCache[str, List[int]] = LruCache(max_size=5, size_callback=len)
cache: LruCache[str, List[int]] = LruCache(
max_size=5, size_callback=len, server_name="test_server"
)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
@ -282,7 +284,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_zero_size_drop_from_cache(self) -> None:
"""Test that `drop_from_cache` works correctly with 0-sized entries."""
cache: LruCache[str, List[int]] = LruCache(
max_size=5, size_callback=lambda x: 0
max_size=5, size_callback=lambda x: 0, server_name="test_server"
)
cache["key1"] = []
@ -307,7 +309,9 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase):
def test_evict(self) -> None:
setup_expire_lru_cache_entries(self.hs)
cache: LruCache[str, int] = LruCache(max_size=5, clock=self.hs.get_clock())
cache: LruCache[str, int] = LruCache(
max_size=5, server_name="test_server", clock=self.hs.get_clock()
)
# Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1
@ -359,7 +363,9 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs)
cache: LruCache[str, int] = LruCache(max_size=4, clock=self.hs.get_clock())
cache: LruCache[str, int] = LruCache(
max_size=4, server_name="test_server", clock=self.hs.get_clock()
)
cache["key1"] = 1
cache["key2"] = 2
@ -396,7 +402,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
def test_invalidate_simple(self) -> None:
cache: LruCache[str, int] = LruCache(
max_size=10, extra_index_cb=lambda k, v: str(v)
max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v)
)
cache["key1"] = 1
cache["key2"] = 2
@ -411,7 +417,7 @@ class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
def test_invalidate_multi(self) -> None:
cache: LruCache[str, int] = LruCache(
max_size=10, extra_index_cb=lambda k, v: str(v)
max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v)
)
cache["key1"] = 1
cache["key2"] = 1