diff --git a/changelog.d/18602.misc b/changelog.d/18602.misc
new file mode 100644
index 000000000..637d84268
--- /dev/null
+++ b/changelog.d/18602.misc
@@ -0,0 +1 @@
+Speed up bulk device deletion.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index de5b38cac..9e3e70ec1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -76,7 +76,7 @@ from synapse.storage.databases.main.registration import (
     LoginTokenLookupResult,
     LoginTokenReused,
 )
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import JsonDict, Requester, StrCollection, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
@@ -1547,6 +1547,31 @@ class AuthHandler:
             user_id, (token_id for _, token_id, _ in tokens_and_devices)
         )
 
+    async def delete_access_tokens_for_devices(
+        self,
+        user_id: str,
+        device_ids: StrCollection,
+    ) -> None:
+        """Invalidate access tokens for the given devices
+
+        Args:
+            user_id: ID of user the tokens belong to
+            device_ids: IDs of the devices the tokens are associated with.
+                All tokens that are associated with these devices will be
+                deleted.
+        """
+        tokens_and_devices = await self.store.user_delete_access_tokens_for_devices(
+            user_id,
+            device_ids,
+        )
+
+        # see if any modules want to know about this
+        if self.password_auth_provider.on_logged_out_callbacks:
+            for token, _, device_id in tokens_and_devices:
+                await self.password_auth_provider.on_logged_out(
+                    user_id=user_id, device_id=device_id, access_token=token
+                )
+
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
     ) -> None:
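The new handler method replaces a per-device loop with a single bulk call. A minimal sketch of the call-site difference follows; `logout_devices` is a hypothetical helper for illustration, not part of the patch:

```python
from synapse.handlers.auth import AuthHandler
from synapse.types import StrCollection


async def logout_devices(
    auth_handler: AuthHandler, user_id: str, device_ids: StrCollection
) -> None:
    # Hypothetical call site. Before this patch the only option was one
    # storage round-trip (plus cache invalidation) per device:
    #
    #     for device_id in device_ids:
    #         await auth_handler.delete_access_tokens_for_user(
    #             user_id, device_id=device_id
    #         )
    #
    # The new method performs the same work in batched queries:
    await auth_handler.delete_access_tokens_for_devices(user_id, device_ids)
```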
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 8f9bf92fd..c6e44dae6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -671,12 +671,12 @@ class DeviceHandler(DeviceWorkerHandler):
             except_device_id: optional device id which should not be deleted
         """
         device_map = await self.store.get_devices_by_user(user_id)
-        device_ids = list(device_map)
         if except_device_id is not None:
-            device_ids = [d for d in device_ids if d != except_device_id]
-        await self.delete_devices(user_id, device_ids)
+            device_map.pop(except_device_id, None)
+        user_device_ids = device_map.keys()
+        await self.delete_devices(user_id, user_device_ids)
 
-    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
+    async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
         """Delete several devices
 
         Args:
@@ -695,17 +695,10 @@ class DeviceHandler(DeviceWorkerHandler):
         else:
             raise
 
-        # Delete data specific to each device. Not optimised as it is not
-        # considered as part of a critical path.
-        for device_id in device_ids:
-            await self._auth_handler.delete_access_tokens_for_user(
-                user_id, device_id=device_id
-            )
-            await self.store.delete_e2e_keys_by_device(
-                user_id=user_id, device_id=device_id
-            )
-
-            if self.hs.config.experimental.msc3890_enabled:
+        # Delete data specific to each device. Not optimised as it's an
+        # experimental MSC.
+        if self.hs.config.experimental.msc3890_enabled:
+            for device_id in device_ids:
                 # Remove any local notification settings for this device in accordance
                 # with MSC3890.
                 await self._account_data_handler.remove_account_data_for_user(
@@ -713,6 +706,13 @@ class DeviceHandler(DeviceWorkerHandler):
                     f"org.matrix.msc3890.local_notification_settings.{device_id}",
                 )
 
+        # If we're deleting a lot of devices, a bunch of them may not have any
+        # to-device messages queued up. We filter those out to avoid scheduling
+        # unnecessary tasks.
+        devices_with_messages = await self.store.get_devices_with_messages(
+            user_id, device_ids
+        )
+        for device_id in devices_with_messages:
             # Delete device messages asynchronously and in batches using the task scheduler.
             # We specify an upper stream id to avoid deleting non-delivered messages
             # if a user re-uses a device ID.
@@ -726,6 +726,10 @@
                 },
             )
 
+        await self._auth_handler.delete_access_tokens_for_devices(
+            user_id, device_ids=device_ids
+        )
+
         # Pushers are deleted after `delete_access_tokens_for_user` is called so that
         # modules using `on_logged_out` hook can use them if needed.
         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
@@ -819,10 +823,11 @@
             # This should only happen if there are no updates, so we bail.
             return
 
-        for device_id in device_ids:
-            logger.debug(
-                "Notifying about update %r/%r, ID: %r", user_id, device_id, position
-            )
+        if logger.isEnabledFor(logging.DEBUG):
+            for device_id in device_ids:
+                logger.debug(
+                    "Notifying about update %r/%r, ID: %r", user_id, device_id, position
+                )
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
@@ -922,9 +927,6 @@
         # can't call self.delete_device because that will clobber the
         # access token so call the storage layer directly
         await self.store.delete_devices(user_id, [old_device_id])
-        await self.store.delete_e2e_keys_by_device(
-            user_id=user_id, device_id=old_device_id
-        )
 
         # tell everyone that the old device is gone and that the dehydrated
         # device has a new display name
@@ -946,7 +948,6 @@
             raise errors.NotFoundError()
 
         await self.delete_devices(user_id, [device_id])
-        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
 
     @wrap_as_background_process("_handle_new_device_update_async")
     async def _handle_new_device_update_async(self) -> None:
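The logging change above guards the loop with `logger.isEnabledFor`, so the per-device iteration is skipped entirely when DEBUG logging is off. A standalone sketch of the pattern, using only the standard library (`notify_device_updates` is an illustrative name, not Synapse code):

```python
import logging
from typing import List

logger = logging.getLogger(__name__)


def notify_device_updates(user_id: str, device_ids: List[str], position: int) -> None:
    # Guarding with isEnabledFor() skips the whole loop when DEBUG is off,
    # rather than iterating over potentially thousands of devices only for
    # logger.debug() to discard each record.
    if logger.isEnabledFor(logging.DEBUG):
        for device_id in device_ids:
            logger.debug(
                "Notifying about update %r/%r, ID: %r", user_id, device_id, position
            )
```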
            await self._clock.sleep(1)
 
+    async def get_devices_with_messages(
+        self, user_id: str, device_ids: StrCollection
+    ) -> StrCollection:
+        """Get the matching device IDs that have messages in the device inbox."""
+
+        def get_devices_with_messages_txn(
+            txn: LoggingTransaction,
+            batch_device_ids: StrCollection,
+        ) -> StrCollection:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "device_id", batch_device_ids
+            )
+            sql = f"""
+                SELECT DISTINCT device_id FROM device_inbox
+                WHERE {clause} AND user_id = ?
+            """
+            args.append(user_id)
+            txn.execute(sql, args)
+            return {row[0] for row in txn}
+
+        results: Set[str] = set()
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            batch_results = await self.db_pool.runInteraction(
+                "get_devices_with_messages",
+                get_devices_with_messages_txn,
+                batch_device_ids,
+                # We don't need to run in a transaction as it's a single query
+                db_autocommit=True,
+            )
+
+            results.update(batch_results)
+
+        return results
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
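`get_devices_with_messages` keeps each SQL `IN (...)` list bounded by chunking the device IDs with `batch_iter`. A self-contained sketch of that batching pattern; the local `batch_iter` mirrors `synapse.util.iterutils.batch_iter`, and the table and column names are taken from the query above:

```python
from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield fixed-size tuples until the source iterator is exhausted.
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())


device_ids = [f"DEVICE{i:04}" for i in range(2500)]
for batch in batch_iter(device_ids, 1000):
    # Each chunk becomes one bounded IN (...) clause instead of a single
    # query with 2500 placeholders.
    placeholders = ", ".join("?" for _ in batch)
    sql = (
        "SELECT DISTINCT device_id FROM device_inbox "
        f"WHERE device_id IN ({placeholders}) AND user_id = ?"
    )
    print(f"batch of {len(batch)}: {len(sql)} chars of SQL")
```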
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 6191f22cd..941d278e6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -282,7 +282,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             "count_devices_by_users", count_devices_by_users_txn, user_ids
         )
 
-    @cached()
+    @cached(tree=True)
     async def get_device(
         self, user_id: str, device_id: str
     ) -> Optional[Mapping[str, Any]]:
@@ -1861,7 +1861,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
             raise StoreError(500, "Problem storing device.")
 
-    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
+    async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
         """Deletes several devices.
 
         Args:
@@ -1885,11 +1885,49 @@
                 values=device_ids,
                 keyvalues={"user_id": user_id},
             )
-            self._invalidate_cache_and_stream_bulk(
-                txn, self.get_device, [(user_id, device_id) for device_id in device_ids]
+
+            # Also delete associated e2e keys.
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
             )
 
-        for batch in batch_iter(device_ids, 100):
+            # We're bulk deleting potentially many devices at once, so
+            # let's not invalidate the cache for each device individually.
+            # Instead, we will invalidate the cache for the user as a whole.
+            self._invalidate_cache_and_stream(txn, self.get_device, (user_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id,)
+            )
+
+        for batch in batch_iter(device_ids, 1000):
             await self.db_pool.runInteraction(
                 "delete_devices", _delete_devices_txn, batch
             )
@@ -2061,32 +2099,36 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         context = get_active_span_text_map()
 
         def add_device_changes_txn(
-            txn: LoggingTransaction, stream_ids: List[int]
+            txn: LoggingTransaction,
+            batch_device_ids: StrCollection,
+            stream_ids: List[int],
         ) -> None:
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
-                device_ids,
+                batch_device_ids,
                 stream_ids,
             )
 
             self._add_device_outbound_room_poke_txn(
                 txn,
                 user_id,
-                device_ids,
+                batch_device_ids,
                 room_ids,
                 stream_ids,
                 context,
             )
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                add_device_changes_txn,
-                stream_ids,
-            )
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            async with self._device_list_id_gen.get_next_mult(
+                len(batch_device_ids)
+            ) as stream_ids:
+                await self.db_pool.runInteraction(
+                    "add_device_change_to_stream",
+                    add_device_changes_txn,
+                    batch_device_ids,
+                    stream_ids,
+                )
 
         return stream_ids[-1]
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 341e7014d..0700b0087 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -593,7 +593,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
             txn, self.count_e2e_one_time_keys, (user_id, device_id)
         )
 
-    @cached(max_entries=10000)
+    @cached(max_entries=10000, tree=True)
     async def count_e2e_one_time_keys(
         self, user_id: str, device_id: str
     ) -> Mapping[str, int]:
@@ -808,7 +808,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
             },
         )
 
-    @cached(max_entries=10000)
+    @cached(max_entries=10000, tree=True)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
     ) -> Sequence[str]:
@@ -1632,46 +1632,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         log_kv({"message": "Device keys stored."})
         return True
 
-    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
-        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
-            log_kv(
-                {
-                    "message": "Deleting keys for device",
-                    "device_id": device_id,
-                    "user_id": user_id,
-                }
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="dehydrated_devices",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_fallback_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
-            )
-
-        await self.db_pool.runInteraction(
-            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
-        )
-
     def _set_e2e_cross_signing_key_txn(
         self,
         txn: LoggingTransaction,
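The `tree=True` flags are what allow the new per-user cache invalidation: a tree cache keys entries hierarchically, so the `(user_id,)` prefix used in `delete_devices` above drops every `(user_id, device_id)` entry at once. A minimal sketch of those semantics, based on Synapse's cache descriptors; the `DeviceCacheSketch` class is illustrative only:

```python
from typing import Any, Mapping, Optional

from synapse.util.caches.descriptors import cached


class DeviceCacheSketch:
    @cached(tree=True)
    async def get_device(
        self, user_id: str, device_id: str
    ) -> Optional[Mapping[str, Any]]:
        ...  # database lookup elided

    def invalidate_user(self, user_id: str) -> None:
        # With tree=True the cache key is hierarchical, so a prefix of the
        # full (user_id, device_id) key invalidates every entry under it in
        # one call. A plain @cached() would need one invalidation per device.
        self.get_device.invalidate((user_id,))
```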
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 40c551bcb..1e21996b1 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -40,14 +40,16 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import JsonDict, UserID, UserInfo
+from synapse.types import JsonDict, StrCollection, UserID, UserInfo
 from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -2801,6 +2803,81 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
+    async def user_delete_access_tokens_for_devices(
+        self,
+        user_id: str,
+        device_ids: StrCollection,
+    ) -> List[Tuple[str, int, Optional[str]]]:
+        """
+        Invalidate access and refresh tokens for the given devices of a user
+
+        Args:
+            user_id: ID of user the tokens belong to
+            device_ids: The devices to delete tokens for.
+        Returns:
+            A list of (token, token id, device id) tuples, one per deleted token
+        """
+
+        def user_delete_access_tokens_for_devices_txn(
+            txn: LoggingTransaction, batch_device_ids: StrCollection
+        ) -> List[Tuple[str, int, Optional[str]]]:
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="refresh_tokens",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=batch_device_ids,
+            )
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "device_id", batch_device_ids
+            )
+            args.append(user_id)
+
+            if self.database_engine.supports_returning:
+                sql = f"""
+                    DELETE FROM access_tokens
+                    WHERE {clause} AND user_id = ?
+                    RETURNING token, id, device_id
+                """
+                txn.execute(sql, args)
+                tokens_and_devices = txn.fetchall()
+            else:
+                tokens_and_devices = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="access_tokens",
+                    column="device_id",
+                    iterable=batch_device_ids,
+                    keyvalues={"user_id": user_id},
+                    retcols=("token", "id", "device_id"),
+                )
+
+                self.db_pool.simple_delete_many_txn(
+                    txn,
+                    table="access_tokens",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    values=batch_device_ids,
+                )
+
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_user_by_access_token,
+                [(t[0],) for t in tokens_and_devices],
+            )
+            return tokens_and_devices
+
+        results: List[Tuple[str, int, Optional[str]]] = []
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            tokens_and_devices = await self.db_pool.runInteraction(
+                "user_delete_access_tokens_for_devices",
+                user_delete_access_tokens_for_devices_txn,
+                batch_device_ids,
+            )
+            results.extend(tokens_and_devices)
+
+        return results
+
     async def delete_access_token(self, access_token: str) -> None:
         def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
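The storage method prefers a single `DELETE ... RETURNING` round-trip where the database engine supports it, falling back to a SELECT-then-DELETE pair elsewhere. A standalone illustration of the same pattern using Python's built-in `sqlite3` (SQLite gained `RETURNING` in 3.35; this is not Synapse's database layer, and the table here is a toy one):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE access_tokens (token TEXT, id INTEGER, device_id TEXT, user_id TEXT)"
)
conn.execute("INSERT INTO access_tokens VALUES ('tok1', 1, 'DEV1', '@alice:test')")

params = ("DEV1", "@alice:test")
if sqlite3.sqlite_version_info >= (3, 35, 0):
    # Engines with RETURNING support delete and read back in one statement.
    rows = conn.execute(
        "DELETE FROM access_tokens WHERE device_id IN (?) AND user_id = ? "
        "RETURNING token, id, device_id",
        params,
    ).fetchall()
else:
    # Fallback: SELECT the doomed rows first, then DELETE them.
    rows = conn.execute(
        "SELECT token, id, device_id FROM access_tokens "
        "WHERE device_id IN (?) AND user_id = ?",
        params,
    ).fetchall()
    conn.execute(
        "DELETE FROM access_tokens WHERE device_id IN (?) AND user_id = ?", params
    )

print(rows)  # [('tok1', 1, 'DEV1')]
```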