From 20615115fba0002a534d96e8e616251cdf3632af Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Aug 2025 09:30:52 +0100 Subject: [PATCH] Make `.sleep(..)` return a coroutine (#18772) This helps ensure that mypy can catch places where we don't await on it, like in #18763. --------- Co-authored-by: Eric Eastwood --- changelog.d/18772.misc | 1 + synapse/state/v2.py | 2 +- synapse/util/__init__.py | 8 ++------ tests/rest/client/test_transactions.py | 2 +- tests/server_notices/__init__.py | 2 +- tests/state/test_v2.py | 4 ++-- tests/util/test_logcontext.py | 19 ++++++++----------- 7 files changed, 16 insertions(+), 22 deletions(-) create mode 100644 changelog.d/18772.misc diff --git a/changelog.d/18772.misc b/changelog.d/18772.misc new file mode 100644 index 000000000..39ceacfd7 --- /dev/null +++ b/changelog.d/18772.misc @@ -0,0 +1 @@ +Make `Clock.sleep(..)` return a coroutine, so that mypy can catch places where we don't await on it. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 69df9eb77..44b191d4e 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -52,7 +52,7 @@ class Clock(Protocol): # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests. # We only ever sleep(0) though, so that other async functions can make forward # progress without waiting for stateres to complete. - def sleep(self, duration_ms: float) -> Awaitable[None]: ... + async def sleep(self, duration_ms: float) -> None: ... class StateResolutionStore(Protocol): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index bd4d20acc..36129c3a6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -27,7 +27,6 @@ from typing import ( Any, Callable, Dict, - Generator, Iterator, Mapping, Optional, @@ -42,7 +41,6 @@ from matrix_common.versionstring import get_distribution_version_string from typing_extensions import ParamSpec from twisted.internet import defer, task -from twisted.internet.defer import Deferred from twisted.internet.interfaces import IDelayedCall, IReactorTime from twisted.internet.task import LoopingCall from twisted.python.failure import Failure @@ -121,13 +119,11 @@ class Clock: _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks - def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": + async def sleep(self, seconds: float) -> None: d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) - res = yield d - return res + await d def time(self) -> float: """Returns the current system time in seconds since epoch.""" diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index af1eecbb3..5f42acb39 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -90,7 +90,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: - yield Clock(reactor).sleep(0) + yield defer.ensureDeferred(Clock(reactor).sleep(0)) return 1, {} @defer.inlineCallbacks diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py index b962da0dd..1d23a126d 100644 --- a/tests/server_notices/__init__.py +++ b/tests/server_notices/__init__.py @@ -131,7 +131,7 @@ class ServerNoticesTests(unittest.HomeserverTestCase): break # Sleep and try again. - self.clock.sleep(0.1) + self.get_success(self.clock.sleep(0.1)) else: self.fail( f"Failed to join the server notices room. No 'join' field in sync_body['rooms']: {sync_body['rooms']}" diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 3e8f9a435..5a0096d8c 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -65,8 +65,8 @@ ORIGIN_SERVER_TS = 0 class FakeClock: - def sleep(self, msec: float) -> "defer.Deferred[None]": - return defer.succeed(None) + async def sleep(self, msec: float) -> None: + return None class FakeEvent: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index f7c5f5fac..af36e685d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -51,20 +51,18 @@ class LoggingContextTestCase(unittest.TestCase): with LoggingContext("test"): self._check_test_key("test") - @defer.inlineCallbacks - def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]: + async def test_sleep(self) -> None: clock = Clock(reactor) - @defer.inlineCallbacks - def competing_callback() -> Generator["defer.Deferred[object]", object, None]: + async def competing_callback() -> None: with LoggingContext("competing"): - yield clock.sleep(0) + await clock.sleep(0) self._check_test_key("competing") - reactor.callLater(0, competing_callback) + reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) with LoggingContext("one"): - yield clock.sleep(0) + await clock.sleep(0) self._check_test_key("one") def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred: @@ -108,9 +106,8 @@ class LoggingContextTestCase(unittest.TestCase): return d2 def test_run_in_background_with_blocking_fn(self) -> defer.Deferred: - @defer.inlineCallbacks - def blocking_function() -> Generator["defer.Deferred[object]", object, None]: - yield Clock(reactor).sleep(0) + async def blocking_function() -> None: + await Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) @@ -133,7 +130,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_run_in_background_with_coroutine(self) -> defer.Deferred: async def testfunc() -> None: self._check_test_key("one") - d = Clock(reactor).sleep(0) + d = defer.ensureDeferred(Clock(reactor).sleep(0)) self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one")