diff --git a/changelog.d/18899.feature b/changelog.d/18899.feature new file mode 100644 index 000000000..ee7141efc --- /dev/null +++ b/changelog.d/18899.feature @@ -0,0 +1 @@ +Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load. \ No newline at end of file diff --git a/synapse/storage/database.py b/synapse/storage/database.py index cfec36e0f..aae029f91 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2653,8 +2653,7 @@ def make_in_list_sql_clause( # These overloads ensure that `columns` and `iterable` values have the same length. -# Suppress "Single overload definition, multiple required" complaint. -@overload # type: ignore[misc] +@overload def make_tuple_in_list_sql_clause( database_engine: BaseDatabaseEngine, columns: Tuple[str, str], @@ -2662,6 +2661,14 @@ def make_tuple_in_list_sql_clause( ) -> Tuple[str, list]: ... +@overload +def make_tuple_in_list_sql_clause( + database_engine: BaseDatabaseEngine, + columns: Tuple[str, str, str], + iterable: Collection[Tuple[Any, Any, Any]], +) -> Tuple[str, list]: ... + + def make_tuple_in_list_sql_clause( database_engine: BaseDatabaseEngine, columns: Tuple[str, ...], diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 779492681..cad26fefa 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -21,6 +21,7 @@ import itertools +import json import logging from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple @@ -62,6 +63,12 @@ PURGE_HISTORY_CACHE_NAME = "ph_cache_fake" # As above, but for invalidating room caches on room deletion DELETE_ROOM_CACHE_NAME = "dr_cache_fake" +# This cache is keyed on a single `(user_id, device_id)` tuple argument, which +# requires special handling when invalidations are sent over replication.
+GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = ( + "_get_e2e_cross_signing_signatures_for_device" +) + # How long between cache invalidation table cleanups, once we have caught up # with the backlog. REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h") @@ -270,6 +277,36 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # room membership. # # self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined] + elif ( + row.cache_func + == GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME + ): + # "keys" is a list of strings, where each string is a + # JSON-encoded representation of the tuple keys, i.e. + # keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]'] + # + # This is a side-effect of not being able to send nested + # information over replication. + for json_str in row.keys: + try: + user_id, device_id = json.loads(json_str) + except (ValueError, TypeError): + # `ValueError` covers both malformed JSON + # (`json.JSONDecodeError` subclasses it) and a decoded value + # that does not unpack into exactly two items. + logger.error( + "Failed to deserialise cache key as valid JSON: %s", + json_str, + ) + continue + + # Invalidate each key. + # + # Note: .invalidate takes a tuple of arguments, hence the need + # to nest our tuple in another tuple.
+ self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined] + ((user_id, device_id),) + ) else: self._attempt_to_invalidate_cache(row.cache_func, row.keys) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index de72e66ce..17ccefe6b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -20,6 +20,7 @@ # # import abc +import json from typing import ( TYPE_CHECKING, Any, @@ -354,15 +355,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) for batch in batch_iter(signature_query, 50): - cross_sigs_result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures_for_devices", - self._get_e2e_cross_signing_signatures_for_devices_txn, - batch, + cross_sigs_result = ( + await self._get_e2e_cross_signing_signatures_for_devices(batch) ) # add each cross-signing signature to the correct device in the result dict. - for user_id, key_id, device_id, signature in cross_sigs_result: + for ( + user_id, + device_id, + ), signature_list in cross_sigs_result.items(): target_device_result = result[user_id][device_id] + # We've only looked up cross-signatures for non-deleted devices with key # data.
assert target_device_result is not None @@ -373,7 +376,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signing_user_signatures = target_device_signatures.setdefault( user_id, {} ) - signing_user_signatures[key_id] = signature + + for key_id, signature in signature_list: + signing_user_signatures[key_id] = signature log_kv(result) return result @@ -479,41 +484,83 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return result - def _get_e2e_cross_signing_signatures_for_devices_txn( - self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] - ) -> List[Tuple[str, str, str, str]]: - """Get cross-signing signatures for a given list of devices - - Returns signatures made by the owners of the devices. - - Returns: a list of results; each entry in the list is a tuple of - (user_id, key_id, target_device_id, signature). + @cached() + def _get_e2e_cross_signing_signatures_for_device( + self, + user_id_and_device_id: Tuple[str, str], + ) -> Sequence[Tuple[str, str]]: """ - signature_query_clauses = [] - signature_query_params = [] + The single-item version of `_get_e2e_cross_signing_signatures_for_devices`. + See @cachedList for why a separate method is needed. + """ + raise NotImplementedError() - for user_id, device_id in device_query: - signature_query_clauses.append( - "target_user_id = ? AND target_device_id = ? AND user_id = ?" + @cachedList( + cached_method_name="_get_e2e_cross_signing_signatures_for_device", + list_name="device_query", + ) + async def _get_e2e_cross_signing_signatures_for_devices( + self, device_query: Iterable[Tuple[str, str]] + ) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]: + """Get cross-signing signatures for a given list of user IDs and devices. + + Args: + device_query: An iterable containing tuples of (user ID, device ID). + + Returns: + A mapping of results.
The keys are the original (user_id, device_id) + tuple, while the value is the list of (key_id, signature) tuples + made by the device's owner. The value will be an empty list if no + such signatures exist for the device. + + Given this method is annotated with `@cachedList`, the return dict's + keys match the tuples within `device_query`, so that cache entries can + be computed from the corresponding values. + + As results are cached, the return type is immutable. + """ + + def _get_e2e_cross_signing_signatures_for_devices_txn( + txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] + ) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]: + where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause( + self.database_engine, + columns=("target_user_id", "target_device_id", "user_id"), + iterable=[ + (user_id, device_id, user_id) for user_id, device_id in device_query + ], ) - signature_query_params.extend([user_id, device_id, user_id]) - signature_sql = """ - SELECT user_id, key_id, target_device_id, signature - FROM e2e_cross_signing_signatures WHERE %s - """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) + signature_sql = f""" + SELECT user_id, key_id, target_device_id, signature + FROM e2e_cross_signing_signatures WHERE {where_clause_sql} + """ - txn.execute(signature_sql, signature_query_params) - return cast( - List[ - Tuple[ - str, - str, - str, - str, - ] - ], - txn.fetchall(), + txn.execute(signature_sql, where_clause_params) + + devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {} + + # `@cachedList` requires we return one key for every item in `device_query`. + # Pre-populate `devices_and_signatures` with each key so that none are missing. + # + # If any are missing, they will be cached as `None`, which is not + # what callers expected. + for user_id, device_id in device_query: + devices_and_signatures.setdefault((user_id, device_id), []) + + # Populate the return dictionary with each found key_id and signature.
+ for user_id, key_id, target_device_id, signature in txn.fetchall(): + signature_tuple = (key_id, signature) + devices_and_signatures[(user_id, target_device_id)].append( + signature_tuple + ) + + return devices_and_signatures + + return await self.db_pool.runInteraction( + "_get_e2e_cross_signing_signatures_for_devices_txn", + _get_e2e_cross_signing_signatures_for_devices_txn, + device_query, ) async def get_e2e_one_time_keys( @@ -1772,26 +1819,73 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker user_id: the user who made the signatures signatures: signatures to add """ - await self.db_pool.simple_insert_many( - "e2e_cross_signing_signatures", - keys=( - "user_id", - "key_id", - "target_user_id", - "target_device_id", - "signature", - ), - values=[ - ( - user_id, - item.signing_key_id, - item.target_user_id, - item.target_device_id, - item.signature, - ) + + def _store_e2e_cross_signing_signatures( + txn: LoggingTransaction, + signatures: "Iterable[SignatureListItem]", + ) -> None: + self.db_pool.simple_insert_many_txn( + txn, + "e2e_cross_signing_signatures", + keys=( + "user_id", + "key_id", + "target_user_id", + "target_device_id", + "signature", + ), + values=[ + ( + user_id, + item.signing_key_id, + item.target_user_id, + item.target_device_id, + item.signature, + ) + for item in signatures + ], + ) + + to_invalidate = [ + # Each entry is a tuple of arguments to + # `_get_e2e_cross_signing_signatures_for_device`, which + # itself takes a tuple. Hence the double-tuple. + ((user_id, item.target_device_id),) for item in signatures - ], - desc="add_e2e_signing_key", + ] + + if to_invalidate: + # Invalidate the local cache of this worker. + for cache_key in to_invalidate: + txn.call_after( + self._get_e2e_cross_signing_signatures_for_device.invalidate, + cache_key, + ) + + # Stream cache invalidate keys over replication. + # + # We can only send a primitive per function argument across + # replication.
+ # + # Encode the array of strings as a JSON string, and we'll unpack + # it on the other side. + to_send = [ + (json.dumps([user_id, item.target_device_id]),) + for item in signatures + ] + + self._send_invalidation_to_replication_bulk( + txn, + cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__, + key_tuples=to_send, + ) + + await self.db_pool.runInteraction( + "add_e2e_signing_key", + _store_e2e_cross_signing_signatures, + # Materialise the iterable: the transaction function iterates over + # it more than once, which would exhaust a one-shot iterator. + list(signatures), ) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 9630cd6d2..47b8f4ddc 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -579,9 +579,12 @@ def cachedList( Used to do batch lookups for an already created cache. One of the arguments is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in - the cache gets passed to the original function, which is expected to results + the cache gets passed to the original function, which is expected to result in a map of key to value for each passed value. The new results are stored in the - original cache. Note that any missing values are cached as None. + original cache. + + Note that any values in the input that end up being missing from both the + cache and the returned dictionary will be cached as `None`. Args: cached_method_name: The name of the single-item lookup method.