Make .sleep(..) return a coroutine (#18772)
This helps ensure that mypy can catch places where we don't await on it, like in #18763. --------- Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
parent
ddbcd859aa
commit
20615115fb
1
changelog.d/18772.misc
Normal file
1
changelog.d/18772.misc
Normal file
@ -0,0 +1 @@
|
||||
Make `Clock.sleep(..)` return a coroutine, so that mypy can catch places where we don't await on it.
|
||||
@ -52,7 +52,7 @@ class Clock(Protocol):
|
||||
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
|
||||
# We only ever sleep(0) though, so that other async functions can make forward
|
||||
# progress without waiting for stateres to complete.
|
||||
def sleep(self, duration_ms: float) -> Awaitable[None]: ...
|
||||
async def sleep(self, duration_ms: float) -> None: ...
|
||||
|
||||
|
||||
class StateResolutionStore(Protocol):
|
||||
|
||||
@ -27,7 +27,6 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Optional,
|
||||
@ -42,7 +41,6 @@ from matrix_common.versionstring import get_distribution_version_string
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from twisted.internet import defer, task
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IDelayedCall, IReactorTime
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.python.failure import Failure
|
||||
@ -121,13 +119,11 @@ class Clock:
|
||||
|
||||
_reactor: IReactorTime = attr.ib()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
d: defer.Deferred[float] = defer.Deferred()
|
||||
with context.PreserveLoggingContext():
|
||||
self._reactor.callLater(seconds, d.callback, seconds)
|
||||
res = yield d
|
||||
return res
|
||||
await d
|
||||
|
||||
def time(self) -> float:
|
||||
"""Returns the current system time in seconds since epoch."""
|
||||
|
||||
@ -90,7 +90,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||
@defer.inlineCallbacks
|
||||
def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
|
||||
yield Clock(reactor).sleep(0)
|
||||
yield defer.ensureDeferred(Clock(reactor).sleep(0))
|
||||
return 1, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@ -131,7 +131,7 @@ class ServerNoticesTests(unittest.HomeserverTestCase):
|
||||
break
|
||||
|
||||
# Sleep and try again.
|
||||
self.clock.sleep(0.1)
|
||||
self.get_success(self.clock.sleep(0.1))
|
||||
else:
|
||||
self.fail(
|
||||
f"Failed to join the server notices room. No 'join' field in sync_body['rooms']: {sync_body['rooms']}"
|
||||
|
||||
@ -65,8 +65,8 @@ ORIGIN_SERVER_TS = 0
|
||||
|
||||
|
||||
class FakeClock:
|
||||
def sleep(self, msec: float) -> "defer.Deferred[None]":
|
||||
return defer.succeed(None)
|
||||
async def sleep(self, msec: float) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class FakeEvent:
|
||||
|
||||
@ -51,20 +51,18 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
with LoggingContext("test"):
|
||||
self._check_test_key("test")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
|
||||
async def test_sleep(self) -> None:
|
||||
clock = Clock(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
|
||||
async def competing_callback() -> None:
|
||||
with LoggingContext("competing"):
|
||||
yield clock.sleep(0)
|
||||
await clock.sleep(0)
|
||||
self._check_test_key("competing")
|
||||
|
||||
reactor.callLater(0, competing_callback)
|
||||
reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback()))
|
||||
|
||||
with LoggingContext("one"):
|
||||
yield clock.sleep(0)
|
||||
await clock.sleep(0)
|
||||
self._check_test_key("one")
|
||||
|
||||
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
|
||||
@ -108,9 +106,8 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
return d2
|
||||
|
||||
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
|
||||
@defer.inlineCallbacks
|
||||
def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
|
||||
yield Clock(reactor).sleep(0)
|
||||
async def blocking_function() -> None:
|
||||
await Clock(reactor).sleep(0)
|
||||
|
||||
return self._test_run_in_background(blocking_function)
|
||||
|
||||
@ -133,7 +130,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
def test_run_in_background_with_coroutine(self) -> defer.Deferred:
|
||||
async def testfunc() -> None:
|
||||
self._check_test_key("one")
|
||||
d = Clock(reactor).sleep(0)
|
||||
d = defer.ensureDeferred(Clock(reactor).sleep(0))
|
||||
self.assertIs(current_context(), SENTINEL_CONTEXT)
|
||||
await d
|
||||
self._check_test_key("one")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user