diff --git a/changelog.d/18992.misc b/changelog.d/18992.misc new file mode 100644 index 000000000..ba4470bff --- /dev/null +++ b/changelog.d/18992.misc @@ -0,0 +1 @@ +Remove `MockClock()` in tests. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 69d3ac78f..7b8e7fe70 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -33,15 +33,17 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, ) from synapse.types import JsonDict -from synapse.util.constants import ONE_HOUR_SECONDS, ONE_MINUTE_SECONDS +from synapse.util.constants import ( + MILLISECONDS_PER_SECOND, + ONE_HOUR_SECONDS, + ONE_MINUTE_SECONDS, +) if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger("synapse.app.homeserver") -MILLISECONDS_PER_SECOND = 1000 - INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS = 5 * ONE_MINUTE_SECONDS """ We wait 5 minutes to send the first set of stats as the server can be quite busy the diff --git a/synapse/util/constants.py b/synapse/util/constants.py index 998601714..7a3d073df 100644 --- a/synapse/util/constants.py +++ b/synapse/util/constants.py @@ -18,3 +18,5 @@ # readability and catching bugs. ONE_MINUTE_SECONDS = 60 ONE_HOUR_SECONDS = 60 * ONE_MINUTE_SECONDS + +MILLISECONDS_PER_SECOND = 1000 diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 9498ea127..0385190f3 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import List, Optional, Sequence, Tuple, cast +from typing import List, Optional, Sequence, Tuple from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias @@ -44,13 +44,12 @@ from synapse.types import DeviceListUpdates, JsonDict from synapse.util.clock import Clock from tests import unittest - -from ..utils import MockClock +from tests.server import get_clock class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): def setUp(self) -> None: - self.clock = MockClock() + self.reactor, self.clock = get_clock() self.store = Mock() self.as_api = Mock() @@ -170,14 +169,14 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): def setUp(self) -> None: - self.clock = MockClock() + self.reactor, self.clock = get_clock() self.as_api = Mock() self.store = Mock() self.service = Mock() self.callback = AsyncMock() self.recoverer = _Recoverer( server_name="test_server", - clock=cast(Clock, self.clock), + clock=self.clock, as_api=self.as_api, store=self.store, service=self.service, @@ -202,7 +201,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): txn.send = AsyncMock(return_value=True) txn.complete = AsyncMock(return_value=None) # wait for exp backoff - self.clock.advance_time(2) + self.reactor.advance(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.complete.call_count) # 2 because it needs to get None to know there are no more txns @@ -229,21 +228,21 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = AsyncMock(return_value=False) txn.complete = AsyncMock(return_value=None) - self.clock.advance_time(2) + self.reactor.advance(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - self.clock.advance_time(4) + self.reactor.advance(4) self.assertEqual(2, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - self.clock.advance_time(8) + self.reactor.advance(8) self.assertEqual(3, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) txn.send = AsyncMock(return_value=True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. - self.clock.advance_time(16) + self.reactor.advance(16) self.assertEqual(1, txn.send.call_count) # new mock reset call count self.assertEqual(1, txn.complete.call_count) self.callback.assert_called_once_with(self.recoverer) @@ -268,7 +267,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) txn.send = AsyncMock(return_value=False) txn.complete = AsyncMock(return_value=None) - self.clock.advance_time(2) + self.reactor.advance(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py index 833cfe628..85e0a3b6b 100644 --- a/tests/config/test_oauth_delegation.py +++ b/tests/config/test_oauth_delegation.py @@ -231,7 +231,10 @@ class MSC3861OAuthDelegation(TestCase): reactor, clock = get_clock() with self.assertRaises(ConfigError): setup_test_homeserver( - self.addCleanup, reactor=reactor, clock=clock, config=config + cleanup_func=self.addCleanup, + config=config, + reactor=reactor, + clock=clock, ) def test_jwt_auth_cannot_be_enabled(self) -> None: @@ -395,7 +398,10 @@ class MasAuthDelegation(TestCase): reactor, clock = get_clock() with self.assertRaises(ConfigError): setup_test_homeserver( - self.addCleanup, reactor=reactor, clock=clock, config=config + cleanup_func=self.addCleanup, + config=config, + reactor=reactor, + clock=clock, ) @skip_unless(HAS_AUTHLIB, "requires authlib") diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 999d7f5e6..6516b7db1 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -49,9 +49,9 @@ from synapse.util.clock import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.server import get_clock from tests.test_utils import event_injection from tests.unittest import override_config -from tests.utils import MockClock class AppServiceHandlerTestCase(unittest.TestCase): @@ -61,6 +61,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_store = Mock() self.mock_as_api = AsyncMock() self.mock_scheduler = Mock() + self.reactor, self.clock = get_clock() + hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) @@ -68,7 +70,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler - hs.get_clock.return_value = MockClock() + hs.get_clock.return_value = self.clock self.handler = ApplicationServicesHandler(hs) self.event_source = hs.get_event_sources() diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 910c24c16..5085a0309 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -21,7 +21,6 @@ # import copy -from unittest import mock from twisted.internet.testing import MemoryReactor @@ -50,7 +49,7 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(replication_layer=mock.Mock()) + return self.setup_test_homeserver() def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_room_keys_handler() diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 4fd0fb922..a359b0a14 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -30,7 +30,7 @@ from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.logging.context import LoggingContext, current_context from tests import unittest -from tests.utils import MockClock +from tests.server import get_clock class SrvResolverTestCase(unittest.TestCase): @@ -105,7 +105,7 @@ class SrvResolverTestCase(unittest.TestCase): @defer.inlineCallbacks def test_from_cache(self) -> Generator["Deferred[object]", object, None]: - clock = MockClock() + reactor, clock = get_clock() dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock.lookupService = Mock(spec_set=[]) diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index d06ea8c3a..6d8754188 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -63,10 +63,6 @@ def check_logcontext(context: LoggingContextOrSentinel) -> None: class FederationClientTests(HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver(reactor=reactor, clock=clock) - return hs - def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py index aec1adb04..6dba21451 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py @@ -37,7 +37,6 @@ from synapse.util.stringutils import ( from tests import unittest from tests.unittest import override_config -from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): @@ -51,12 +50,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - # We need to be able to test advancing time in the homeserver, so we - # replace the test homeserver's default clock with a MockClock, which - # supports advancing time. - return self.setup_test_homeserver(clock=MockClock()) - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.remote_server_name = "remote.homeserver" self.store = hs.get_datastores().main diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 9c9eca541..c22c1a661 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -29,16 +29,19 @@ from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_co from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.types import ISynapseReactor, JsonDict from synapse.util.clock import Clock +from synapse.util.constants import ( + MILLISECONDS_PER_SECOND, +) from tests import unittest -from tests.utils import MockClock +from tests.server import get_clock reactor = cast(ISynapseReactor, _reactor) class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self) -> None: - self.clock = MockClock() + self.reactor, self.clock = get_clock() self.hs = Mock() self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_auth = Mock() @@ -180,8 +183,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" ) - # should NOT have cleaned up yet - self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) + # Advance time just under the cleanup period. + # Should NOT have cleaned up yet + self.reactor.advance((CLEANUP_PERIOD_MS - 1) / MILLISECONDS_PER_SECOND) yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" @@ -189,7 +193,8 @@ class HttpTransactionCacheTestCase(unittest.TestCase): # still using cache cb.assert_called_once_with("an arg") - self.clock.advance_time_msec(CLEANUP_PERIOD_MS) + # Advance time just after the cleanup period. + self.reactor.advance(2 / MILLISECONDS_PER_SECOND) yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index cf8241438..8d2489f71 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -170,7 +170,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # make a second homeserver, configured to use the first one as a key notary self.http_client2 = Mock() - config = default_config(name="keyclient") + config = default_config(server_name="keyclient") config["trusted_key_servers"] = [ { "server_name": self.hs.hostname, diff --git a/tests/server.py b/tests/server.py index 858b41d56..226bdf4bb 100644 --- a/tests/server.py +++ b/tests/server.py @@ -114,7 +114,6 @@ from tests.utils import ( POSTGRES_USER, SQLITE_PERSIST_DB, USE_POSTGRES_FOR_TESTS, - MockClock, default_config, ) @@ -786,9 +785,9 @@ class ThreadPool: def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: - clock = ThreadedMemoryReactorClock() - hs_clock = Clock(clock, server_name="test_server") - return clock, hs_clock + reactor = ThreadedMemoryReactorClock() + hs_clock = Clock(reactor, server_name="test_server") + return reactor, hs_clock @implementer(ITCPTransport) @@ -1020,12 +1019,14 @@ class TestHomeServer(HomeServer): def setup_test_homeserver( + *, cleanup_func: Callable[[Callable[[], None]], None], - name: str = "test", + server_name: str = "test", config: Optional[HomeServerConfig] = None, reactor: Optional[ISynapseReactor] = None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs: Any, + db_txn_limit: Optional[int] = None, + **extra_homeserver_attributes: Any, ) -> HomeServer: """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -1034,29 +1035,41 @@ def setup_test_homeserver( If no datastore is supplied, one is created and given to the homeserver. Args: - cleanup_func : The function used to register a cleanup routine for - after the test. + cleanup_func: The function used to register a cleanup routine for after the + test. + server_name: Homeserver name + config: Homeserver config + reactor: Twisted reactor + homeserver_to_use: Homeserver class to instantiate. + db_txn_limit: Gives the maximum number of database transactions to run per + connection before reconnecting. 0 means no limit. If unset, defaults to None + here which will default upstream to `0`. + **extra_homeserver_attributes: Additional keyword arguments to install as + `@cache_in_self` attributes on the homeserver. For example, `clock` will be + installed as `hs._clock`. Calling this method directly is deprecated: you should instead derive from HomeserverTestCase. """ if reactor is None: - from twisted.internet import reactor as _reactor - - reactor = cast(ISynapseReactor, _reactor) + reactor = ThreadedMemoryReactorClock() if config is None: - config = default_config(name, parse=True) + config = default_config(server_name, parse=True) + + server_name = config.server.server_name + if not isinstance(server_name, str): + raise ConfigError("Must be a string", ("server_name",)) + + if "clock" not in extra_homeserver_attributes: + extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) config.caches.resize_all_caches() - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - database_config = { + database_config: JsonDict = { "name": "psycopg2", "args": { "dbname": test_db, @@ -1088,10 +1101,6 @@ def setup_test_homeserver( "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1}, } - server_name = config.server.server_name - if not isinstance(server_name, str): - raise ConfigError("Must be a string", ("server_name",)) - # Check if we have set up a DB that we can use as a template. global PREPPED_SQLITE_DB_CONN if PREPPED_SQLITE_DB_CONN is None: @@ -1111,8 +1120,8 @@ def setup_test_homeserver( database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] + if db_txn_limit is not None: + database_config["txn_limit"] = db_txn_limit database = DatabaseConnectionConfig("master", database_config) config.database.databases = [database] @@ -1139,7 +1148,7 @@ def setup_test_homeserver( db_conn.close() hs = homeserver_to_use( - name, + server_name, config=config, version_string="Synapse/tests", reactor=reactor, @@ -1149,7 +1158,7 @@ def setup_test_homeserver( cleanup_func(hs.cleanup) # Install @cache_in_self attributes - for key, val in kwargs.items(): + for key, val in extra_homeserver_attributes.items(): setattr(hs, "_" + key, val) # Mock TLS diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 11313fc93..577229c11 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -86,7 +86,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): conn_pool.runWithConnection = runWithConnection - config = default_config(name="test", parse=True) + config = default_config(server_name="test", parse=True) hs = TestHomeServer("test", config=config) if USE_POSTGRES_FOR_TESTS: diff --git a/tests/test_server.py b/tests/test_server.py index 69efceafe..66c5cf9e3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -55,9 +55,9 @@ class JsonResourceTests(unittest.TestCase): reactor, clock = get_clock() self.reactor = reactor self.homeserver = setup_test_homeserver( - self.addCleanup, - clock=clock, + cleanup_func=self.addCleanup, reactor=self.reactor, + clock=clock, ) def test_handler_for_request(self) -> None: @@ -217,9 +217,9 @@ class OptionsResourceTests(unittest.TestCase): reactor, clock = get_clock() self.reactor = reactor self.homeserver = setup_test_homeserver( - self.addCleanup, - clock=clock, + cleanup_func=self.addCleanup, reactor=self.reactor, + clock=clock, ) class DummyResource(Resource): diff --git a/tests/test_state.py b/tests/test_state.py index 16446c16b..ab7b52e90 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -29,7 +29,6 @@ from typing import ( Optional, Set, Tuple, - cast, ) from unittest.mock import AsyncMock, Mock @@ -43,12 +42,11 @@ from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry from synapse.types import MutableStateMap, StateMap from synapse.types.state import StateFilter -from synapse.util.clock import Clock from synapse.util.macaroons import MacaroonGenerator from tests import unittest - -from .utils import MockClock, default_config +from tests.server import get_clock +from tests.utils import default_config _next_event_id = 1000 @@ -248,7 +246,7 @@ class StateTestCase(unittest.TestCase): "hostname", ] ) - clock = cast(Clock, MockClock()) + reactor, clock = get_clock() hs.config = default_config("tesths", True) hs.get_datastores.return_value = Mock( main=self.dummy_store, diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py deleted file mode 100644 index c52f963a7..000000000 --- a/tests/test_test_utils.py +++ /dev/null @@ -1,79 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# [This file includes modifications made by New Vector Limited] -# -# - -from tests import unittest -from tests.utils import MockClock - - -class MockClockTestCase(unittest.TestCase): - def setUp(self) -> None: - self.clock = MockClock() - - def test_advance_time(self) -> None: - start_time = self.clock.time() - - self.clock.advance_time(20) - - self.assertEqual(20, self.clock.time() - start_time) - - def test_later(self) -> None: - invoked = [0, 0] - - def _cb0() -> None: - invoked[0] = 1 - - self.clock.call_later(10, _cb0) - - def _cb1() -> None: - invoked[1] = 1 - - self.clock.call_later(20, _cb1) - - self.assertFalse(invoked[0]) - - self.clock.advance_time(15) - - self.assertTrue(invoked[0]) - self.assertFalse(invoked[1]) - - self.clock.advance_time(5) - - self.assertTrue(invoked[1]) - - def test_cancel_later(self) -> None: - invoked = [0, 0] - - def _cb0() -> None: - invoked[0] = 1 - - t0 = self.clock.call_later(10, _cb0) - - def _cb1() -> None: - invoked[1] = 1 - - self.clock.call_later(20, _cb1) - - self.clock.cancel_call_later(t0) - - self.clock.advance_time(30) - - self.assertFalse(invoked[0]) - self.assertTrue(invoked[1]) diff --git a/tests/unittest.py b/tests/unittest.py index 8be4e635a..9ab052e7c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -80,7 +80,7 @@ from synapse.logging.context import ( from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import ISynapseReactor, JsonDict, Requester, UserID, create_requester from synapse.util.clock import Clock from synapse.util.httpresourcetree import create_resource_tree @@ -99,6 +99,8 @@ from tests.utils import checked_cast, default_config, setupdb setupdb() setup_logging() +logger = logging.getLogger(__name__) + TV = TypeVar("TV") _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) @@ -135,7 +137,7 @@ def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]: return _around -_TConfig = TypeVar("_TConfig", Config, RootConfig) +_TConfig = TypeVar("_TConfig", Config, HomeServerConfig) def deepcopy_config(config: _TConfig) -> _TConfig: @@ -161,13 +163,13 @@ def deepcopy_config(config: _TConfig) -> _TConfig: @functools.lru_cache(maxsize=8) -def _parse_config_dict(config: str) -> RootConfig: +def _parse_config_dict(config: str) -> HomeServerConfig: config_obj = HomeServerConfig() config_obj.parse_config_dict(json.loads(config), "", "") return config_obj -def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig: +def make_homeserver_config_obj(config: Dict[str, Any]) -> HomeServerConfig: """Creates a :class:`HomeServerConfig` instance with the given configuration dict. This is equivalent to:: @@ -392,8 +394,8 @@ class HomeserverTestCase(TestCase): hijacking the authentication system to return a fixed user, and then calling the prepare function. """ + # We need to share the reactor between the homeserver and all of our test utils. self.reactor, self.clock = get_clock() - self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False @@ -511,7 +513,7 @@ class HomeserverTestCase(TestCase): Function to be overridden in subclasses. """ - hs = self.setup_test_homeserver() + hs = self.setup_test_homeserver(reactor=reactor, clock=clock) return hs def create_test_resource(self) -> Resource: @@ -634,7 +636,12 @@ class HomeserverTestCase(TestCase): ) def setup_test_homeserver( - self, server_name: Optional[str] = None, **kwargs: Any + self, + server_name: Optional[str] = None, + config: Optional[JsonDict] = None, + reactor: Optional[ISynapseReactor] = None, + clock: Optional[Clock] = None, + **extra_homeserver_attributes: Any, ) -> HomeServer: """ Set up the test homeserver, meant to be called by the overridable @@ -647,12 +654,15 @@ class HomeserverTestCase(TestCase): Returns: synapse.server.HomeServer """ - kwargs = dict(kwargs) - kwargs.update(self._hs_args) - if "config" not in kwargs: + if config is None: config = self.default_config() - else: - config = kwargs["config"] + + # The sane default is to use the same reactor and clock as our other test utils + if reactor is None: + reactor = self.reactor + + if clock is None: + clock = self.clock # The server name can be specified using either the `name` argument or a config # override. The `name` argument takes precedence over any config overrides. @@ -661,19 +671,24 @@ class HomeserverTestCase(TestCase): # Parse the config from a config dict into a HomeServerConfig config_obj = make_homeserver_config_obj(config) - kwargs["config"] = config_obj # The server name in the config is now `name`, if provided, or the `server_name` # from a config override, or the default of "test". Whichever it is, we # construct a homeserver with a matching name. server_name = config_obj.server.server_name - kwargs["name"] = server_name async def run_bg_updates() -> None: with LoggingContext(name="run_bg_updates", server_name=server_name): self.get_success(stor.db_pool.updates.run_background_updates(False)) - hs = setup_test_homeserver(self.addCleanup, **kwargs) + hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + server_name=server_name, + config=config_obj, + reactor=reactor, + clock=clock, + **extra_homeserver_attributes, + ) stor = hs.get_datastores().main # Run the database background updates, when running against "master". diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index bfcc6cd12..eda2d586f 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -19,23 +19,22 @@ # # -from typing import List, cast +from typing import List from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.clock import Clock -from tests.utils import MockClock +from tests.server import get_clock from .. import unittest class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self) -> None: - clock = MockClock() + reactor, clock = get_clock() cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", - clock=cast(Clock, clock), + clock=clock, max_len=1, ) @@ -44,11 +43,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache["key"], "value") def test_eviction(self) -> None: - clock = MockClock() + reactor, clock = get_clock() cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", - clock=cast(Clock, clock), + clock=clock, max_len=2, ) @@ -63,11 +62,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key3"), "value3") def test_iterable_eviction(self) -> None: - clock = MockClock() + reactor, clock = get_clock() cache: ExpiringCache[str, List[int]] = ExpiringCache( cache_name="test", server_name="testserver", - clock=cast(Clock, clock), + clock=clock, max_len=5, iterable=True, ) @@ -87,25 +86,25 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key4"), [6, 7]) def test_time_eviction(self) -> None: - clock = MockClock() + reactor, clock = get_clock() cache: ExpiringCache[str, int] = ExpiringCache( cache_name="test", server_name="testserver", - clock=cast(Clock, clock), + clock=clock, expiry_ms=1000, ) cache["key"] = 1 - clock.advance_time(0.5) + reactor.advance(0.5) cache["key2"] = 2 self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.get("key2"), 2) - clock.advance_time(0.9) + reactor.advance(0.9) self.assertEqual(cache.get("key"), None) self.assertEqual(cache.get("key2"), 2) - clock.advance_time(1) + reactor.advance(1) self.assertEqual(cache.get("key"), None) self.assertEqual(cache.get("key2"), None) diff --git a/tests/utils.py b/tests/utils.py index d1b66d415..051388ee2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,27 +24,19 @@ import os import signal from types import FrameType, TracebackType from typing import ( - Any, - Callable, Dict, - List, Literal, Optional, - Tuple, Type, TypeVar, Union, overload, ) -import attr -from typing_extensions import ParamSpec - from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import create_engine @@ -140,21 +132,27 @@ def setupdb() -> None: @overload -def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ... +def default_config( + server_name: str, parse: Literal[False] = ... +) -> Dict[str, object]: ... @overload -def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ... +def default_config(server_name: str, parse: Literal[True]) -> HomeServerConfig: ... def default_config( - name: str, parse: bool = False + server_name: str, parse: bool = False ) -> Union[Dict[str, object], HomeServerConfig]: """ Create a reasonable test config. + + Args: + server_name: homeserver name + parse: TODO """ config_dict = { - "server_name": name, + "server_name": server_name, # Setting this to an empty list turns off federation sending. "federation_sender_instances": [], "media_store_path": "media", @@ -247,101 +245,6 @@ def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def] return getRawHeaders -P = ParamSpec("P") - - -@attr.s(slots=True, auto_attribs=True) -class Timer: - absolute_time: float - callback: Callable[[], None] - expired: bool - - -# TODO: Make this generic over a ParamSpec? -@attr.s(slots=True, auto_attribs=True) -class Looper: - func: Callable[..., Any] - interval: float # seconds - last: float - args: Tuple[object, ...] - kwargs: Dict[str, object] - - -class MockClock: - now = 1000.0 - - def __init__(self) -> None: - # Timers in no particular order - self.timers: List[Timer] = [] - self.loopers: List[Looper] = [] - - def time(self) -> float: - return self.now - - def time_msec(self) -> int: - return int(self.time() * 1000) - - def call_later( - self, - delay: float, - callback: Callable[P, object], - *args: P.args, - **kwargs: P.kwargs, - ) -> Timer: - ctx = current_context() - - def wrapped_callback() -> None: - set_current_context(ctx) - callback(*args, **kwargs) - - t = Timer(self.now + delay, wrapped_callback, False) - self.timers.append(t) - - return t - - def looping_call( - self, - function: Callable[P, object], - interval: float, - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) - - def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: - if timer.expired: - if not ignore_errs: - raise Exception("Cannot cancel an expired timer") - - timer.expired = True - self.timers = [t for t in self.timers if t != timer] - - # For unit testing - def advance_time(self, secs: float) -> None: - self.now += secs - - timers = self.timers - self.timers = [] - - for t in timers: - if t.expired: - raise Exception("Timer already expired") - - if self.now >= t.absolute_time: - t.expired = True - t.callback() - else: - self.timers.append(t) - - for looped in self.loopers: - if looped.last + looped.interval < self.now: - looped.func(*looped.args, **looped.kwargs) - looped.last = self.now - - def advance_time_msec(self, ms: float) -> None: - self.advance_time(ms / 1000.0) - - async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: """Creates and persist a creation event for the given room"""