synapse/tests/util/test_async_helpers.py
Eric Eastwood 5143f93dc9
Fix server_name in logging context for multiple Synapse instances in one process (#18868)
### Background

As part of Element's plan to support a light form of vhosting (virtual
host) (multiple instances of Synapse in the same Python process), we're
currently diving into the details and implications of running multiple
instances of Synapse in the same Python process.

"Per-tenant logging" tracked internally by
https://github.com/element-hq/synapse-small-hosts/issues/48

### Prior art

Previously, we exposed `server_name` by providing a static logging
`MetadataFilter` that injected the values:


205d9e4fc4/synapse/config/logger.py (L216)

While this can work fine for the normal case of one Synapse instance per
Python process, this configures things globally and isn't compatible
when we try to start multiple Synapse instances because each subsequent
tenant will overwrite the previous tenant.


### What does this PR do?

We remove the `MetadataFilter` and replace it by tracking the
`server_name` in the `LoggingContext` and expose it with our existing
[`LoggingContextFilter`](205d9e4fc4/synapse/logging/context.py (L584-L622))
that we already use to expose information about the `request`.

This means that the `server_name` value follows wherever we log as
expected even when we have multiple Synapse instances running in the
same process.


### A note on logcontext

Anywhere Synapse mistakenly uses the `sentinel` logcontext to log
something, we won't know which server sent the log. We've been fixing up
`sentinel` logcontext usage as tracked by
https://github.com/element-hq/synapse/issues/18905

Any further `sentinel` logcontext usage we find in the future can be
fixed piecemeal as normal.


d2a966f922/docs/log_contexts.md (L71-L81)


### Testing strategy

1. Adjust your logging config to include `%(server_name)s` in the format
    ```yaml
    formatters:
        precise:
            format: '%(asctime)s - %(server_name)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
    ```
1. Start Synapse: `poetry run synapse_homeserver --config-path
homeserver.yaml`
1. Make some requests (`curl
http://localhost:8008/_matrix/client/versions`, etc)
1. Open the homeserver logs and notice the `server_name` in the logs as
expected. `unknown_server_from_sentinel_context` is expected for the
`sentinel` logcontext (things outside of Synapse).
2025-09-26 17:10:48 -05:00

738 lines
26 KiB
Python

#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
import traceback
from typing import Any, Coroutine, List, NoReturn, Optional, Tuple, TypeVar
from parameterized import parameterized_class
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.util.async_helpers import (
AwakenableSleeper,
ObservableDeferred,
concurrently_execute,
delay_cancellation,
gather_optional_coroutines,
stop_cancellation,
timeout_deferred,
)
from tests.server import get_clock
from tests.unittest import TestCase
# Module-level logger, named after this module, for diagnostics emitted by the tests.
logger = logging.getLogger(__name__)
# Generic placeholder used to annotate helpers whose result type varies per call site.
T = TypeVar("T")
class ObservableDeferredTest(TestCase):
    """Checks how `ObservableDeferred` fans a single outcome out to its observers."""

    def test_succeed(self) -> None:
        """A successful result reaches every observer, first observer first."""
        source_d: "Deferred[int]" = Deferred()
        fan_out = ObservableDeferred(source_d)

        first = fan_out.observe()
        second = fan_out.observe()
        for observer in (first, second):
            self.assertFalse(observer.called)

        # One slot per observer, filled in by the callbacks registered below.
        seen: List[Optional[int]] = [None, None]

        def ensure_second_not_fired(value: int) -> int:
            # Runs on the first observer: the second one must still be pending.
            self.assertFalse(second.called)
            return value

        def remember(value: int, slot: int) -> int:
            seen[slot] = value
            return value

        first.addBoth(ensure_second_not_fired)
        first.addCallback(remember, 0)
        second.addCallback(remember, 1)

        source_d.callback(123)

        self.assertEqual(seen[0], 123, "observer 1 callback result")
        self.assertEqual(seen[1], 123, "observer 2 callback result")

    def test_failure(self) -> None:
        """A failure reaches every observer, first observer first."""
        source_d: Deferred = Deferred()
        fan_out = ObservableDeferred(source_d, consumeErrors=True)

        first = fan_out.observe()
        second = fan_out.observe()
        for observer in (first, second):
            self.assertFalse(observer.called)

        # One slot per observer, filled in by the errbacks registered below.
        captured: List[Optional[Failure]] = [None, None]

        def ensure_second_not_fired(value: int) -> int:
            # Runs on the first observer: the second one must still be pending.
            self.assertFalse(second.called)
            return value

        def remember_failure(failure: Failure, slot: int) -> None:
            # Implicitly returns None, so the failure stops propagating here.
            captured[slot] = failure

        first.addBoth(ensure_second_not_fired)
        first.addErrback(remember_failure, 0)
        second.addErrback(remember_failure, 1)

        # Fire the errback from inside the `except` block of a genuinely raised
        # exception.
        try:
            raise Exception("gah!")
        except Exception as exc:
            source_d.errback(exc)

        first_failure, second_failure = captured
        assert first_failure is not None
        self.assertEqual(
            str(first_failure.value), "gah!", "observer 1 errback result"
        )
        assert second_failure is not None
        self.assertEqual(
            str(second_failure.value), "gah!", "observer 2 errback result"
        )

    def test_cancellation(self) -> None:
        """Cancelling one observer must leave all the other observers untouched."""
        source_d: "Deferred[int]" = Deferred()
        fan_out = ObservableDeferred(source_d, consumeErrors=True)

        first = fan_out.observe()
        second = fan_out.observe()
        third = fan_out.observe()
        for observer in (first, second, third):
            self.assertFalse(observer.called)

        # Only the middle observer gets cancelled.
        second.cancel()
        self.assertFalse(first.called)
        self.failureResultOf(second, CancelledError)
        self.assertFalse(third.called)

        # The untouched observers still receive the eventual result...
        source_d.callback(123)
        self.assertEqual(first.result, 123, "observer 1 callback result")
        self.assertEqual(third.result, 123, "observer 3 callback result")

        # ...and so does an observer that only subscribes afterwards.
        late = fan_out.observe()
        self.assertEqual(late.result, 123, "observer 4 callback result")
class TimeoutDeferredTest(TestCase):
    """Exercises `timeout_deferred` against a manually driven `Clock`."""

    def setUp(self) -> None:
        # Fake reactor clock; tests move time forward with `pump(...)`.
        self.clock = Clock()

    def test_times_out(self) -> None:
        """On timeout the wrapped deferred is cancelled and the returned deferred
        fails with `TimeoutError`.
        """
        canceller_ran = False

        def record_cancel(_d: Deferred) -> None:
            nonlocal canceller_ran
            canceller_ran = True

        never_fires: Deferred = Deferred(record_cancel)
        guarded_d = timeout_deferred(
            deferred=never_fires,
            timeout=1.0,
            reactor=self.clock,
        )

        # Nothing may have happened before time moves.
        self.assertNoResult(guarded_d)
        self.assertFalse(canceller_ran, "deferred was cancelled prematurely")

        self.clock.pump((1.0,))

        self.assertTrue(canceller_ran, "deferred was not cancelled by timeout")
        self.failureResultOf(guarded_d, defer.TimeoutError)

    def test_times_out_when_canceller_throws(self) -> None:
        """A canceller that raises must not prevent the `TimeoutError`.

        Covers the workaround for https://twistedmatrix.com/trac/ticket/9534
        """

        def exploding_canceller(_d: Deferred) -> None:
            raise Exception("can't cancel this deferred")

        never_fires: Deferred = Deferred(exploding_canceller)
        guarded_d = timeout_deferred(
            deferred=never_fires,
            timeout=1.0,
            reactor=self.clock,
        )
        self.assertNoResult(guarded_d)

        self.clock.pump((1.0,))

        self.failureResultOf(guarded_d, defer.TimeoutError)

    async def test_logcontext_is_preserved_on_cancellation(self) -> None:
        """The caller's logcontext must still be current after a timeout fires."""
        # Sanity check: the test begins in the sentinel context.
        self.assertEqual(current_context(), SENTINEL_CONTEXT)

        saw_cancelled_error = False

        def note_cancellation(failure: Failure) -> None:
            """
            Transparent errback used purely for test book-keeping.

            Sets `saw_cancelled_error` when the failure is a `CancelledError`,
            logs anything else, and always re-raises so that errbacks added
            later (by `timeout_deferred(...)`) still see the original failure.
            """
            nonlocal saw_cancelled_error
            if not failure.check(CancelledError):
                logger.error(
                    "Expected incomplete_d to fail with `CancelledError` because our "
                    "`timeout_deferred(...)` utility canceled it but saw %s",
                    failure,
                )
            else:
                saw_cancelled_error = True

            # Hand the failure on unchanged to the rest of the errback chain.
            failure.raiseException()

        # This deferred is never resolved by the test itself.
        never_fires: Deferred = Deferred()
        never_fires.addErrback(note_cancellation)

        with LoggingContext(name="one", server_name="test_server") as ctx:
            guarded_d = timeout_deferred(
                deferred=never_fires,
                timeout=1.0,
                reactor=self.clock,
            )
            self.assertNoResult(guarded_d)
            # Setting up the timeout must not have switched logcontext.
            self.assertIs(current_context(), ctx)

            # Drive the clock far enough to trigger the timeout.
            #
            # Pumping runs pending reactor callbacks by hand, so it has to be done
            # from the sentinel logcontext; otherwise our logcontext would leak
            # into the reactor and be picked up by whatever it runs next.
            # `PreserveLoggingContext()` switches to the sentinel for the duration
            # of the block and restores our logcontext on exit.
            with PreserveLoggingContext():
                # `1.0` seconds matches `timeout=1.0` above.
                self.clock.pump((1.0,))

            # The timeout should have cancelled the wrapped deferred by now.
            self.assertTrue(
                saw_cancelled_error,
                "incomplete deferred was not cancelled",
            )
            # Callers observe a `TimeoutError`, not the inner `CancelledError`.
            self.failureResultOf(guarded_d, defer.TimeoutError)
            # And the logcontext is unchanged.
            self.assertIs(current_context(), ctx)

        # Leaving the block returns us to the sentinel context.
        self.assertEqual(current_context(), SENTINEL_CONTEXT)
class _TestException(Exception):
    """Marker exception raised on purpose by tests in this module."""
class ConcurrentlyExecuteTest(TestCase):
    """Tests for `concurrently_execute`: concurrency limit and traceback fidelity."""

    def test_limits_runners(self) -> None:
        """With more items than the limit, only `limit` tasks run at any one time."""
        start_count = 0
        pending: List["Deferred[int]"] = []
        handled: List[int] = []

        async def blocked_task(item: int) -> None:
            # Count the entry and note which item this task received...
            nonlocal start_count
            start_count += 1
            handled.append(item)

            # ...then park until the test releases this task.
            go_ahead: "Deferred[int]" = Deferred()
            pending.append(go_ahead)
            await go_ahead

        # Five items, at most three in flight.
        run_d = ensureDeferred(concurrently_execute(blocked_task, [1, 2, 3, 4, 5], 3))

        # Exactly three tasks should have started.
        self.assertEqual(start_count, 3)
        self.assertEqual(len(pending), 3)

        # Releasing one task...
        pending.pop().callback(0)

        # ...makes room for exactly one more.
        self.assertEqual(start_count, 4)
        self.assertEqual(len(pending), 3)

        # The overall run is not finished yet.
        self.assertNoResult(run_d)

        # Release tasks until none are left; newly started tasks append to
        # `pending` while we drain it, hence the `while` loop.
        while pending:
            pending.pop().callback(0)

        # Every item was processed and the run completed.
        self.assertEqual(start_count, 5)
        self.assertCountEqual(handled, [1, 2, 3, 4, 5])
        self.successResultOf(run_d)

    def test_preserves_stacktraces(self) -> None:
        """An exception raised in the callback keeps its full stacktrace."""
        gate: "Deferred[int]" = Deferred()

        # NB: the names `caller` and `callback` are asserted against the traceback
        # frame names below, so they must not be renamed.
        async def callback(v: int) -> None:
            # The original author notes this does not work at all without an
            # await here.
            await gate
            raise _TestException("bah")

        async def caller() -> None:
            try:
                await concurrently_execute(callback, [1], 2)
            except _TestException as exc:
                frames = traceback.extract_tb(exc.__traceback__)
                # Expected: "caller" and "concurrently_execute" on top, "callback"
                # at the bottom.
                self.assertEqual(frames[0].name, "caller")
                self.assertEqual(frames[1].name, "concurrently_execute")
                self.assertEqual(frames[-1].name, "callback")
            else:
                self.fail("No exception thrown")

        outcome_d = ensureDeferred(caller())
        gate.callback(0)
        self.successResultOf(outcome_d)

    def test_preserves_stacktraces_on_preformed_failure(self) -> None:
        """A pre-built `Failure` surfaced by the callback keeps its stacktrace."""
        gate: "Deferred[int]" = Deferred()
        preformed = Failure(_TestException("bah"))

        # NB: the names `caller` and `callback` are asserted against the traceback
        # frame names below, so they must not be renamed.
        async def callback(v: int) -> None:
            # The original author notes this does not work at all without an
            # await here.
            await gate
            await defer.fail(preformed)

        async def caller() -> None:
            try:
                await concurrently_execute(callback, [1], 2)
            except _TestException as exc:
                all_frames = traceback.extract_tb(exc.__traceback__)
                # Twisted's own frames are dropped; their precise details are
                # not what this test is about.
                frames = traceback.StackSummary(
                    [frame for frame in all_frames if "/twisted/" not in frame.filename]
                )
                # Top of the stack: "caller", then "concurrently_execute".
                self.assertEqual(frames[0].name, "caller")
                self.assertEqual(frames[1].name, "concurrently_execute")
                # (frames from inside `concurrently_execute` sit in between)
                # Bottom of the stack: "callback".
                self.assertEqual(frames[-1].name, "callback")
            else:
                self.fail("No exception thrown")

        outcome_d = ensureDeferred(caller())
        gate.callback(0)
        self.successResultOf(outcome_d)
@parameterized_class(
    ("wrapper",),
    [("stop_cancellation",), ("delay_cancellation",)],
)
class CancellationWrapperTests(TestCase):
    """Behaviour shared by `stop_cancellation` and `delay_cancellation`.

    The class is parameterized over `wrapper`, so every test runs once per function.
    """

    # Name of the function under test, injected by `parameterized_class`.
    wrapper: str

    def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
        """Wrap `deferred` with the function selected by `self.wrapper`.

        Raises:
            ValueError: if `self.wrapper` names neither supported function.
        """
        wrappers = {
            "stop_cancellation": stop_cancellation,
            "delay_cancellation": delay_cancellation,
        }
        wrap = wrappers.get(self.wrapper)
        if wrap is None:
            raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
        return wrap(deferred)

    def test_succeed(self) -> None:
        """A successful result passes through to the wrapping `Deferred`."""
        original: "Deferred[str]" = Deferred()
        wrapped = self.wrap_deferred(original)

        original.callback("success")

        self.assertTrue(wrapped.called)
        self.assertEqual("success", self.successResultOf(wrapped))

    def test_failure(self) -> None:
        """A `Failure` passes through to the wrapping `Deferred`."""
        original: "Deferred[str]" = Deferred()
        wrapped = self.wrap_deferred(original)

        original.errback(ValueError("abc"))

        self.assertTrue(wrapped.called)
        self.failureResultOf(wrapped, ValueError)
        # The original deferred must not be left holding the failure.
        self.assertIsNone(original.result, "`Failure` was not consumed")
class StopCancellationTests(TestCase):
    """Tests specific to `stop_cancellation`."""

    def test_cancellation(self) -> None:
        """Cancelling the wrapper fails it at once but leaves the original running."""
        original: "Deferred[str]" = Deferred()
        wrapped = stop_cancellation(original)

        # Cancel only the wrapper.
        wrapped.cancel()
        self.assertTrue(wrapped.called)
        self.failureResultOf(wrapped, CancelledError)
        self.assertFalse(
            original.called, "Original `Deferred` was unexpectedly cancelled"
        )

        # Fail the original afterwards. The `Failure` has to be consumed, or
        # unwanted tracebacks end up in the logs.
        original.errback(ValueError("abc"))
        self.assertIsNone(original.result, "`Failure` was not consumed")
class DelayCancellationTests(TestCase):
    """Tests specific to `delay_cancellation`."""

    def _assert_cancellation_is_delayed(self, cancel_calls: int) -> None:
        """Cancel a wrapped `Deferred` `cancel_calls` times and check the outcome.

        The wrapper must stay unresolved until the original `Deferred` finishes,
        and only then fail with `CancelledError`.
        """
        original: "Deferred[str]" = Deferred()
        wrapped = delay_cancellation(original)

        for _ in range(cancel_calls):
            wrapped.cancel()
        self.assertNoResult(wrapped)
        self.assertFalse(
            original.called, "Original `Deferred` was unexpectedly cancelled"
        )

        # Fail the original. The `Failure` has to be consumed, or unwanted
        # tracebacks end up in the logs.
        original.errback(ValueError("abc"))
        self.assertIsNone(original.result, "`Failure` was not consumed")

        # Only now does the wrapper report the cancellation.
        self.failureResultOf(wrapped, CancelledError)

    def test_deferred_cancellation(self) -> None:
        """Cancelling the wrapper of a `Deferred` waits for the original."""
        self._assert_cancellation_is_delayed(cancel_calls=1)

    def test_coroutine_cancellation(self) -> None:
        """Cancelling the wrapper of a coroutine waits for the coroutine."""
        blocker: "Deferred[None]" = Deferred()
        finished: "Deferred[None]" = Deferred()

        async def task() -> NoReturn:
            await blocker
            finished.callback(None)
            # Twisted should consume this exception; otherwise unwanted
            # tracebacks end up in the logs.
            raise ValueError("abc")

        wrapped = delay_cancellation(task())

        # Cancel only the wrapper.
        wrapped.cancel()
        self.assertNoResult(wrapped)
        self.assertFalse(blocker.called, "Cancellation was propagated too deep")
        self.assertFalse(finished.called)

        # Let the coroutine run to its end.
        blocker.callback(None)
        self.assertTrue(finished.called)

        # With the coroutine done, the wrapper reports the cancellation.
        self.failureResultOf(wrapped, CancelledError)

    def test_suppresses_second_cancellation(self) -> None:
        """A second `cancel()` on the wrapper has no additional effect.

        Same scenario as `test_deferred_cancellation`, with the wrapper
        cancelled twice.
        """
        self._assert_cancellation_is_delayed(cancel_calls=2)

    def test_propagates_cancelled_error(self) -> None:
        """A `CancelledError` raised by the original is passed through unchanged."""
        original: "Deferred[str]" = Deferred()
        wrapped = delay_cancellation(original)

        # The original itself fails with a `CancelledError`.
        cancelled_error = CancelledError()
        original.errback(cancelled_error)

        # The wrapper must carry that very same exception object.
        self.assertTrue(wrapped.called)
        self.assertIs(cancelled_error, self.failureResultOf(wrapped).value)

    def test_preserves_logcontext(self) -> None:
        """The logging context is intact when the delayed cancellation surfaces."""
        blocker: "Deferred[None]" = Deferred()

        async def inner() -> None:
            await make_deferred_yieldable(blocker)

        async def outer() -> None:
            with LoggingContext(name="c", server_name="test_server") as ctx:
                try:
                    await delay_cancellation(inner())
                except CancelledError:
                    # Finishes cleanly unless the logging context is wrong.
                    self.assertEqual(ctx, current_context())
                else:
                    self.fail("`CancelledError` was not raised")

        # Start `outer()`; it blocks inside `inner()`.
        outer_d = defer.ensureDeferred(outer())
        self.assertEqual(SENTINEL_CONTEXT, current_context())

        outer_d.cancel()

        # Unblocking lets `outer()` consume the `CancelledError` and check its
        # logging context.
        blocker.callback(None)
        self.successResultOf(outer_d)
class AwakenableSleeperTests(TestCase):
    """Tests AwakenableSleeper"""

    def _make_sleeper(self) -> Tuple[Any, AwakenableSleeper]:
        """Build a fresh test reactor and an `AwakenableSleeper` driven by it."""
        reactor, _ = get_clock()
        return reactor, AwakenableSleeper(reactor)

    def test_sleep(self) -> None:
        """A sleep completes by itself once enough reactor time has elapsed."""
        reactor, sleeper = self._make_sleeper()

        sleep_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        reactor.pump([0.0])
        self.assertFalse(sleep_d.called)

        reactor.advance(0.5)
        self.assertFalse(sleep_d.called)

        # 1.1 reactor-seconds in total: the sleep must be over.
        reactor.advance(0.6)
        self.assertTrue(sleep_d.called)

    def test_explicit_wake(self) -> None:
        """`wake(name)` completes a pending sleep before its timeout."""
        reactor, sleeper = self._make_sleeper()

        sleep_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        reactor.pump([0.0])
        self.assertFalse(sleep_d.called)

        reactor.advance(0.5)
        self.assertFalse(sleep_d.called)

        sleeper.wake("name")
        self.assertTrue(sleep_d.called)

        # Move past the point where the original timeout would have fired.
        reactor.advance(0.6)

    def test_multiple_sleepers_timeout(self) -> None:
        """Sleeps sharing a name each time out on their own schedule."""
        reactor, sleeper = self._make_sleeper()

        early_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        reactor.advance(0.6)
        self.assertFalse(early_d.called)

        # A second sleep under the same name, started 0.6s later.
        late_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        # Only the earlier sleep is due at this point.
        reactor.advance(0.6)
        self.assertTrue(early_d.called)
        self.assertFalse(late_d.called)

        reactor.advance(0.6)
        self.assertTrue(late_d.called)

    def test_multiple_sleepers_wake(self) -> None:
        """A single `wake(name)` completes every sleep registered under that name."""
        reactor, sleeper = self._make_sleeper()

        early_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        reactor.advance(0.5)
        self.assertFalse(early_d.called)

        # A second sleep under the same name.
        late_d = defer.ensureDeferred(sleeper.sleep("name", 1000))

        # Neither is due yet.
        reactor.advance(0.3)
        for sleep_d in (early_d, late_d):
            self.assertFalse(sleep_d.called)

        # One explicit wake releases both.
        sleeper.wake("name")
        for sleep_d in (early_d, late_d):
            self.assertTrue(sleep_d.called)
class GatherCoroutineTests(TestCase):
    """Tests for `gather_optional_coroutines`"""

    def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
        """Return a coroutine together with the deferred it blocks on.

        The coroutine awaits the deferred inside `PreserveLoggingContext()` and
        returns whatever the deferred resolves with.
        """
        pending: "defer.Deferred[T]" = defer.Deferred()

        async def wait_for_pending() -> T:
            with PreserveLoggingContext():
                return await pending

        return wait_for_pending(), pending

    def _assert_pending_in_sentinel(self, gather_deferred: "defer.Deferred") -> None:
        """Check the gather has no result yet and we are in the sentinel context."""
        self.assertNoResult(gather_deferred)
        self.assertEqual(current_context(), SENTINEL_CONTEXT)

    def test_single(self) -> None:
        """Gathering a single coroutine yields a one-element tuple."""
        with LoggingContext(name="test_ctx", server_name="test_server") as ctx:
            deferred: "defer.Deferred[None]"
            coroutine, deferred = self.make_coroutine()

            gather_deferred = defer.ensureDeferred(
                gather_optional_coroutines(coroutine)
            )

            # Still waiting, and currently in the sentinel context.
            self._assert_pending_in_sentinel(gather_deferred)

            # Resolving the deferred lets the coroutine finish.
            deferred.callback(None)

            # Everything has resolved, so the results are available.
            result = self.successResultOf(gather_deferred)
            self.assertEqual(result, (None,))

            # The original logging context is current again.
            self.assertEqual(current_context(), ctx)

    def test_multiple_resolve(self) -> None:
        """Gathering several coroutines yields their results in argument order."""
        with LoggingContext(name="test_ctx", server_name="test_server") as ctx:
            deferred1: "defer.Deferred[int]"
            coroutine1, deferred1 = self.make_coroutine()
            deferred2: "defer.Deferred[str]"
            coroutine2, deferred2 = self.make_coroutine()

            gather_deferred = defer.ensureDeferred(
                gather_optional_coroutines(coroutine1, coroutine2)
            )

            # Still waiting, and currently in the sentinel context.
            self._assert_pending_in_sentinel(gather_deferred)

            # Resolving just one of the coroutines is not enough.
            deferred2.callback("test")
            self._assert_pending_in_sentinel(gather_deferred)

            deferred1.callback(1)

            # Everything has resolved, so the results are available.
            result = self.successResultOf(gather_deferred)
            self.assertEqual(result, (1, "test"))

            # The original logging context is current again.
            self.assertEqual(current_context(), ctx)

    def test_multiple_fail(self) -> None:
        """If one of several coroutines fails, the gather fails with that error."""
        with LoggingContext(name="test_ctx", server_name="test_server") as ctx:
            deferred1: "defer.Deferred[int]"
            coroutine1, deferred1 = self.make_coroutine()
            deferred2: "defer.Deferred[str]"
            coroutine2, deferred2 = self.make_coroutine()

            gather_deferred = defer.ensureDeferred(
                gather_optional_coroutines(coroutine1, coroutine2)
            )

            # Still waiting, and currently in the sentinel context.
            self._assert_pending_in_sentinel(gather_deferred)

            # Make one of the coroutines raise.
            exc = Exception("test")
            deferred2.errback(exc)

            # The gather fails immediately, carrying that same exception.
            result_exc = self.failureResultOf(gather_deferred)
            self.assertEqual(result_exc.value, exc)

            # The original logging context is current again.
            self.assertEqual(current_context(), ctx)