diff --git a/changelog.d/18756.misc b/changelog.d/18756.misc new file mode 100644 index 000000000..c4353f776 --- /dev/null +++ b/changelog.d/18756.misc @@ -0,0 +1 @@ +Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts. \ No newline at end of file diff --git a/changelog.d/18762.feature b/changelog.d/18762.feature new file mode 100644 index 000000000..aa8e91de0 --- /dev/null +++ b/changelog.d/18762.feature @@ -0,0 +1 @@ +Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306). \ No newline at end of file diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 28537e187..96169fd45 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) 
.unwrap(); @@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(!matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); - b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); + b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None)); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index e0832ada1..ec027ca25 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { }]; pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/underride/.io.element.msc4306.rule.unsubscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: false }, + )]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/underride/.io.element.msc4306.rule.subscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: true }, + )]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.call"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index db406acb8..1cbca4c63 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@
-106,8 +106,11 @@ pub struct PushRuleEvaluator { /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, - // If MSC4210 (remove legacy mentions) is enabled. + /// If MSC4210 (remove legacy mentions) is enabled. msc4210_enabled: bool, + + /// If MSC4306 (thread subscriptions) is enabled. + msc4306_enabled: bool, } #[pymethods] @@ -126,6 +129,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, ))] pub fn py_new( flattened_keys: BTreeMap, @@ -138,6 +142,7 @@ impl PushRuleEvaluator { room_version_feature_flags: Vec, msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Result { let body = match flattened_keys.get("content.body") { Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(), @@ -156,6 +161,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, }) } @@ -167,12 +173,19 @@ impl PushRuleEvaluator { /// /// Returns the set of actions, if any, that match (filtering out any /// `dont_notify` and `coalesce` actions). - #[pyo3(signature = (push_rules, user_id=None, display_name=None))] + /// + /// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled) + /// The thread subscription state corresponding to the thread containing this event. + /// - `None` if the event is not in a thread, or if MSC4306 is disabled. 
+ /// - `Some(true)` if the event is in a thread and the user has a subscription for that thread + /// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread + #[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] pub fn run( &self, push_rules: &FilteredPushRules, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Vec { 'outer: for (push_rule, enabled) in push_rules.iter() { if !enabled { @@ -204,7 +217,12 @@ impl PushRuleEvaluator { Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }), ); - match self.match_condition(condition, user_id, display_name) { + match self.match_condition( + condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => {} Ok(false) => continue 'outer, Err(err) => { @@ -237,14 +255,20 @@ impl PushRuleEvaluator { } /// Check if the given condition matches. - #[pyo3(signature = (condition, user_id=None, display_name=None))] + #[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] fn matches( &self, condition: Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> bool { - match self.match_condition(&condition, user_id, display_name) { + match self.match_condition( + &condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => true, Ok(false) => false, Err(err) => { @@ -262,6 +286,7 @@ impl PushRuleEvaluator { condition: &Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Result { let known_condition = match condition { Condition::Known(known) => known, @@ -393,6 +418,13 @@ impl PushRuleEvaluator { && self.room_version_feature_flags.contains(&flag) } } + KnownCondition::Msc4306ThreadSubscription { subscribed } => { + if !self.msc4306_enabled { + false 
+ } else { + msc4306_thread_subscription_state == Some(*subscribed) + } + } }; Ok(result) @@ -536,10 +568,11 @@ fn push_rule_evaluator() { vec![], true, false, + false, ) .unwrap(); - let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None); assert_eq!(result.len(), 3); } @@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() { flags, true, false, + false, ) .unwrap(); @@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() { &FilteredPushRules::default(), Some("@bob:example.org"), None, + None, ); assert_eq!(result.len(), 3); @@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), + &FilteredPushRules::py_new( + rules, + BTreeMap::new(), + true, + false, + true, + false, + false, + false, + ), + None, None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index bd0e853ac..b07a12e5c 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -369,6 +369,10 @@ pub enum KnownCondition { RoomVersionSupports { feature: Cow<'static, str>, }, + #[serde(rename = "io.element.msc4306.thread_subscription")] + Msc4306ThreadSubscription { + subscribed: bool, + }, } impl<'source> IntoPyObject<'source> for Condition { @@ -547,11 +551,13 @@ pub struct FilteredPushRules { msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] + #[allow(clippy::too_many_arguments)] pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, @@ -560,6 +566,7 @@ impl FilteredPushRules { msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Self { Self { push_rules, @@ -569,6 +576,7 @@ impl FilteredPushRules { 
msc3664_enabled, msc4028_push_encrypted_events, msc4210_enabled, + msc4306_enabled, } } @@ -619,6 +627,10 @@ impl FilteredPushRules { return false; } + if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") { + return false; + } + true }) .map(|r| { diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b832c2f6a..ec4d707b7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,12 @@ class Codes(str, Enum): # Part of MSC4155 INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED" + # Part of MSC4306: Thread Subscriptions + MSC4306_CONFLICTING_UNSUBSCRIPTION = ( + "IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION" + ) + MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD" + class CodeMessageException(RuntimeError): """An exception with integer code, a message string attributes and optional headers. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 1e9722e0d..7f511d570 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -37,7 +37,6 @@ Events are replicated via a separate events stream. 
""" import logging -from enum import Enum from typing import ( TYPE_CHECKING, Dict, @@ -68,25 +67,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class QueueNames(str, Enum): - PRESENCE_MAP = "presence_map" - KEYED_EDU = "keyed_edu" - KEYED_EDU_CHANGED = "keyed_edu_changed" - EDUS = "edus" - POS_TIME = "pos_time" - PRESENCE_DESTINATIONS = "presence_destinations" - - -queue_name_to_gauge_map: Dict[QueueNames, LaterGauge] = {} - -for queue_name in QueueNames: - queue_name_to_gauge_map[queue_name] = LaterGauge( - name=f"synapse_federation_send_queue_{queue_name.value}_size", - desc="", - labelnames=[SERVER_NAME_LABEL], - ) - - class FederationRemoteSendQueue(AbstractFederationSender): """A drop in replacement for FederationSender""" @@ -131,15 +111,23 @@ class FederationRemoteSendQueue(AbstractFederationSender): # we make a new function, so we need to make a new function so the inner # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. - def register(queue_name: QueueNames, queue: Sized) -> None: - queue_name_to_gauge_map[queue_name].register_hook( - lambda: {(self.server_name,): len(queue)} + def register(name: str, queue: Sized) -> None: + LaterGauge( + name="synapse_federation_send_queue_%s_size" % (queue_name,), + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(queue)}, ) - for queue_name in QueueNames: - queue = getattr(self, queue_name.value) - assert isinstance(queue, Sized) - register(queue_name, queue=queue) + for queue_name in [ + "presence_map", + "keyed_edu", + "keyed_edu_changed", + "edus", + "pos_time", + "presence_destinations", + ]: + register(queue_name, getattr(self, queue_name)) self.clock.looping_call(self._clear_queue, 30 * 1000) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 21af12354..8befbe372 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -199,24 +199,6 @@ 
sent_pdus_destination_dist_total = Counter( labelnames=[SERVER_NAME_LABEL], ) -transaction_queue_pending_destinations_gauge = LaterGauge( - name="synapse_federation_transaction_queue_pending_destinations", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - -transaction_queue_pending_pdus_gauge = LaterGauge( - name="synapse_federation_transaction_queue_pending_pdus", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - -transaction_queue_pending_edus_gauge = LaterGauge( - name="synapse_federation_transaction_queue_pending_edus", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - # Time (in s) to wait before trying to wake up destinations that have # catch-up outstanding. # Please note that rate limiting still applies, so while the loop is @@ -416,28 +398,38 @@ class FederationSender(AbstractFederationSender): # map from destination to PerDestinationQueue self._per_destination_queues: Dict[str, PerDestinationQueue] = {} - transaction_queue_pending_destinations_gauge.register_hook( - lambda: { + LaterGauge( + name="synapse_federation_transaction_queue_pending_destinations", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { (self.server_name,): sum( 1 for d in self._per_destination_queues.values() if d.transmission_loop_running ) - } + }, ) - transaction_queue_pending_pdus_gauge.register_hook( - lambda: { + + LaterGauge( + name="synapse_federation_transaction_queue_pending_pdus", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { (self.server_name,): sum( d.pending_pdu_count() for d in self._per_destination_queues.values() ) - } + }, ) - transaction_queue_pending_edus_gauge.register_hook( - lambda: { + LaterGauge( + name="synapse_federation_transaction_queue_pending_edus", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { (self.server_name,): sum( d.pending_edu_count() for d in self._per_destination_queues.values() ) - } + }, ) self._is_processing = False diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 
fb9f96267..b25311749 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -173,18 +173,6 @@ state_transition_counter = Counter( labelnames=["locality", "from", "to", SERVER_NAME_LABEL], ) -presence_user_to_current_state_size_gauge = LaterGauge( - name="synapse_handlers_presence_user_to_current_state_size", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - -presence_wheel_timer_size_gauge = LaterGauge( - name="synapse_handlers_presence_wheel_timer_size", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 @@ -791,8 +779,11 @@ class PresenceHandler(BasePresenceHandler): EduTypes.PRESENCE, self.incoming_presence ) - presence_user_to_current_state_size_gauge.register_hook( - lambda: {(self.server_name,): len(self.user_to_current_state)} + LaterGauge( + name="synapse_handlers_presence_user_to_current_state_size", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.user_to_current_state)}, ) # The per-device presence state, maps user to devices to per-device presence state. 
@@ -891,8 +882,11 @@ class PresenceHandler(BasePresenceHandler): 60 * 1000, ) - presence_wheel_timer_size_gauge.register_hook( - lambda: {(self.server_name,): len(self.wheel_timer)} + LaterGauge( + name="synapse_handlers_presence_wheel_timer_size", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.wheel_timer)}, ) # Used to handle sending of presence to newly joined users/servers diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index 79e4d6040..bda434294 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -1,9 +1,15 @@ import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Optional -from synapse.api.errors import AuthError, NotFoundError -from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription -from synapse.types import UserID +from synapse.api.constants import RelationTypes +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.events import relation_from_event +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscription, +) +from synapse.types import EventOrderings, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler: room_id: str, thread_root_event_id: str, *, - automatic: bool, + automatic_event_id: Optional[str], ) -> Optional[int]: """Sets or updates a user's subscription settings for a specific thread root. Args: requester_user_id: The ID of the user whose settings are being updated. thread_root_event_id: The event ID of the thread root. - automatic: whether the user was subscribed by an automatic decision by - their client. + automatic_event_id: if the user was subscribed by an automatic decision by + their client, the event ID that caused this. 
Returns: The stream ID for this update, if the update isn't no-opped. Raises: NotFoundError if the user cannot access the thread root event, or it isn't - known to this homeserver. + known to this homeserver. Ditto for the automatic cause event if supplied. + + SynapseError(400, M_NOT_IN_THREAD): if client supplied an automatic cause event + but that event is not in the thread of the given thread root. + + SynapseError(409, M_CONFLICTING_UNSUBSCRIPTION): if client requested an automatic subscription + but it was skipped because the cause event is logically later than an unsubscription. """ # First check that the user can access the thread root event # and that it exists try: - event = await self.event_handler.get_event( + thread_root_event = await self.event_handler.get_event( user_id, room_id, thread_root_event_id ) - if event is None: + if thread_root_event is None: raise NotFoundError("No such thread root") except AuthError: logger.info("rejecting thread subscriptions change (thread not accessible)") raise NotFoundError("No such thread root") - return await self.store.subscribe_user_to_thread( + if automatic_event_id: + autosub_cause_event = await self.event_handler.get_event( + user_id, room_id, automatic_event_id + ) + if autosub_cause_event is None: + raise NotFoundError("Automatic subscription event not found") + relation = relation_from_event(autosub_cause_event) + if ( + relation is None + or relation.rel_type != RelationTypes.THREAD + or relation.parent_id != thread_root_event_id + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Automatic subscription must use an event in the thread", + errcode=Codes.MSC4306_NOT_IN_THREAD, + ) + + automatic_event_orderings = EventOrderings.from_event(autosub_cause_event) + else: + automatic_event_orderings = None + + outcome = await self.store.subscribe_user_to_thread( user_id.to_string(), - event.room_id, + room_id, thread_root_event_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) + if isinstance(outcome,
AutomaticSubscriptionConflicted): + raise SynapseError( + HTTPStatus.CONFLICT, + "Automatic subscription obsoleted by an unsubscription request.", + errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + ) + + return outcome + async def unsubscribe_user_from_thread( self, user_id: UserID, room_id: str, thread_root_event_id: str ) -> Optional[int]: diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index c5274c758..a9b049f90 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -164,12 +164,12 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]: return counts -in_flight_requests = LaterGauge( +LaterGauge( name="synapse_http_server_in_flight_requests_count", desc="", labelnames=["method", "servlet", SERVER_NAME_LABEL], + caller=_get_in_flight_counts, ) -in_flight_requests.register_hook(_get_in_flight_counts) class RequestMetrics: diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 8c99d3c77..11e2551a1 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -31,7 +31,6 @@ from typing import ( Dict, Generic, Iterable, - List, Mapping, Optional, Sequence, @@ -74,6 +73,8 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" +all_gauges: Dict[str, Collector] = {} + HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") SERVER_NAME_LABEL = "server_name" @@ -162,47 +163,42 @@ class LaterGauge(Collector): name: str desc: str labelnames: Optional[StrSequence] = attr.ib(hash=False) - # List of callbacks: each callback should either return a value (if there are no - # labels for this metric), or dict mapping from a label tuple to a value - _hooks: List[ - Callable[ - [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]] - ] - ] = attr.ib(factory=list, hash=False) + # callback: should either return a value (if there are no labels for this metric), + # or dict mapping from a label tuple to a value + caller: Callable[ + [], 
Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]] + ] def collect(self) -> Iterable[Metric]: # The decision to add `SERVER_NAME_LABEL` is from the `LaterGauge` usage itself # (we don't enforce it here, one level up). g = GaugeMetricFamily(self.name, self.desc, labels=self.labelnames) # type: ignore[missing-server-name-label] - for hook in self._hooks: - try: - hook_result = hook() - except Exception: - logger.exception( - "Exception running callback for LaterGauge(%s)", self.name - ) - yield g - return - - if isinstance(hook_result, (int, float)): - g.add_metric([], hook_result) - else: - for k, v in hook_result.items(): - g.add_metric(k, v) - + try: + calls = self.caller() + except Exception: + logger.exception("Exception running callback for LaterGauge(%s)", self.name) yield g + return - def register_hook( - self, - hook: Callable[ - [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]] - ], - ) -> None: - self._hooks.append(hook) + if isinstance(calls, (int, float)): + g.add_metric([], calls) + else: + for k, v in calls.items(): + g.add_metric(k, v) + + yield g def __attrs_post_init__(self) -> None: + self._register() + + def _register(self) -> None: + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering", self.name) + REGISTRY.unregister(all_gauges.pop(self.name)) + REGISTRY.register(self) + all_gauges[self.name] = self # `MetricsEntry` only makes sense when it is a `Protocol`, @@ -254,7 +250,7 @@ class InFlightGauge(Generic[MetricsEntry], Collector): # Protects access to _registrations self._lock = threading.Lock() - REGISTRY.register(self) + self._register_with_collector() def register( self, @@ -345,6 +341,14 @@ class InFlightGauge(Generic[MetricsEntry], Collector): gauge.add_metric(labels=key, value=getattr(metrics, name)) yield gauge + def _register_with_collector(self) -> None: + if self.name in all_gauges.keys(): + logger.warning("%s already registered, reregistering", 
self.name) + REGISTRY.unregister(all_gauges.pop(self.name)) + + REGISTRY.register(self) + all_gauges[self.name] = self + class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily): """ diff --git a/synapse/notifier.py b/synapse/notifier.py index d56a7b26b..448a715e2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -86,24 +86,6 @@ users_woken_by_stream_counter = Counter( labelnames=["stream", SERVER_NAME_LABEL], ) - -notifier_listeners_gauge = LaterGauge( - name="synapse_notifier_listeners", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - -notifier_rooms_gauge = LaterGauge( - name="synapse_notifier_rooms", - desc="", - labelnames=[SERVER_NAME_LABEL], -) -notifier_users_gauge = LaterGauge( - name="synapse_notifier_users", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - T = TypeVar("T") @@ -299,16 +281,28 @@ class Notifier: ) } - notifier_listeners_gauge.register_hook(count_listeners) - notifier_rooms_gauge.register_hook( - lambda: { + LaterGauge( + name="synapse_notifier_listeners", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=count_listeners, + ) + + LaterGauge( + name="synapse_notifier_rooms", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { (self.server_name,): count( bool, list(self.room_to_user_streams.values()) ) - } + }, ) - notifier_users_gauge.register_hook( - lambda: {(self.server_name,): len(self.user_to_user_stream)} + LaterGauge( + name="synapse_notifier_users", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.user_to_user_stream)}, ) def add_replication_callback(self, cb: Callable[[], None]) -> None: diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index da4fa29da..bb9d5dbca 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -25,6 +25,7 @@ from typing import ( Any, Collection, Dict, + FrozenSet, List, Mapping, Optional, @@ -477,8 +478,18 @@ class 
BulkPushRuleEvaluator: event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc4210_enabled, + self.hs.config.experimental.msc4306_enabled, ) + msc4306_thread_subscribers: Optional[FrozenSet[str]] = None + if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE: + # pull out, in batch, all local subscribers to this thread + # (in the common case, they will all be getting processed for push + # rules right now) + msc4306_thread_subscribers = await self.store.get_subscribers_to_thread( + event.room_id, thread_id + ) + for uid, rules in rules_by_user.items(): if event.sender == uid: continue @@ -503,7 +514,13 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - actions = evaluator.run(rules, uid, display_name) + msc4306_thread_subscription_state: Optional[bool] = None + if msc4306_thread_subscribers is not None: + msc4306_thread_subscription_state = uid in msc4306_thread_subscribers + + actions = evaluator.run( + rules, uid, display_name, msc4306_thread_subscription_state + ) if "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index f033eaaeb..0f14c7e38 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -106,18 +106,6 @@ user_ip_cache_counter = Counter( "synapse_replication_tcp_resource_user_ip_cache", "", labelnames=[SERVER_NAME_LABEL] ) -tcp_resource_total_connections_gauge = LaterGauge( - name="synapse_replication_tcp_resource_total_connections", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - -tcp_command_queue_gauge = LaterGauge( - name="synapse_replication_tcp_command_queue", - desc="Number of inbound RDATA/POSITION commands queued for processing", - labelnames=["stream_name", SERVER_NAME_LABEL], -) - # the type of the 
entries in _command_queues_by_stream _StreamCommandQueue = Deque[ @@ -255,8 +243,11 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections: List[IReplicationConnection] = [] - tcp_resource_total_connections_gauge.register_hook( - lambda: {(self.server_name,): len(self._connections)} + LaterGauge( + name="synapse_replication_tcp_resource_total_connections", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self._connections)}, ) # When POSITION or RDATA commands arrive, we stick them in a queue and process @@ -275,11 +266,14 @@ class ReplicationCommandHandler: # from that connection. self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} - tcp_command_queue_gauge.register_hook( - lambda: { + LaterGauge( + name="synapse_replication_tcp_command_queue", + desc="Number of inbound RDATA/POSITION commands queued for processing", + labelnames=["stream_name", SERVER_NAME_LABEL], + caller=lambda: { (stream_name, self.server_name): len(queue) for stream_name, queue in self._command_queues_by_stream.items() - } + }, ) self._is_master = hs.config.worker.worker_app is None diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 4d8381646..969f0303e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -527,11 +527,9 @@ pending_commands = LaterGauge( name="synapse_replication_tcp_protocol_pending_commands", desc="", labelnames=["name", SERVER_NAME_LABEL], -) -pending_commands.register_hook( - lambda: { + caller=lambda: { (p.name, p.server_name): len(p.pending_commands) for p in connected_connections - } + }, ) @@ -546,11 +544,9 @@ transport_send_buffer = LaterGauge( name="synapse_replication_tcp_protocol_transport_send_buffer", desc="", labelnames=["name", SERVER_NAME_LABEL], -) -transport_send_buffer.register_hook( - lambda: { + caller=lambda: { (p.name, p.server_name): transport_buffer_size(p) for p in 
connected_connections - } + }, ) @@ -575,12 +571,10 @@ tcp_transport_kernel_send_buffer = LaterGauge( name="synapse_replication_tcp_protocol_transport_kernel_send_buffer", desc="", labelnames=["name", SERVER_NAME_LABEL], -) -tcp_transport_kernel_send_buffer.register_hook( - lambda: { + caller=lambda: { (p.name, p.server_name): transport_kernel_read_buffer_size(p, False) for p in connected_connections - } + }, ) @@ -588,10 +582,8 @@ tcp_transport_kernel_read_buffer = LaterGauge( name="synapse_replication_tcp_protocol_transport_kernel_read_buffer", desc="", labelnames=["name", SERVER_NAME_LABEL], -) -tcp_transport_kernel_read_buffer.register_hook( - lambda: { + caller=lambda: { (p.name, p.server_name): transport_kernel_read_buffer_size(p, True) for p in connected_connections - } + }, ) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9694fff4f..ec7e935d6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen): NAME = "thread_subscriptions" ROW_TYPE = ThreadSubscriptionsStreamRow - def __init__(self, hs: Any): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen): self, instance_name: str, from_token: int, to_token: int, limit: int ) -> StreamUpdateResult: updates = await self.store.get_updated_thread_subscriptions( - from_token, to_token, limit + from_id=from_token, to_id=to_token, limit=limit ) rows = [ ( diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index eb724500b..4e7b5d06d 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,7 +1,6 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, 
Tuple -from synapse._pydantic_compat import StrictBool from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -12,6 +11,7 @@ from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.types import JsonDict, RoomID from synapse.types.rest import RequestBodyModel +from synapse.util.pydantic_models import AnyEventId if TYPE_CHECKING: from synapse.server import HomeServer @@ -32,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet): self.handler = hs.get_thread_subscriptions_handler() class PutBody(RequestBodyModel): - automatic: StrictBool + automatic: Optional[AnyEventId] + """ + If supplied, the event ID of an event giving rise to this automatic subscription. + + If omitted, this subscription is a manual subscription. + """ async def on_GET( self, request: SynapseRequest, room_id: str, thread_root_id: str @@ -63,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet): raise SynapseError( HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM ) - requester = await self.auth.get_user_by_req(request) - body = parse_and_validate_json_object_from_request(request, self.PutBody) + requester = await self.auth.get_user_by_req(request) + await self.handler.subscribe_user_to_thread( requester.user, room_id, thread_root_id, - automatic=body.automatic, + automatic_event_id=body.automatic, ) return HTTPStatus.OK, {} diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bbdc5b9d2..f7aec16c9 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -100,12 +100,6 @@ sql_txn_duration = Counter( labelnames=["desc", SERVER_NAME_LABEL], ) -background_update_status = LaterGauge( - name="synapse_background_update_status", - desc="Background update status", - labelnames=[SERVER_NAME_LABEL], -) - # Unique indexes which have been added in background updates. 
Maps from table name # to the name of the background update which added the unique index to that table. @@ -617,8 +611,11 @@ class DatabasePool: ) self.updates = BackgroundUpdater(hs, self) - background_update_status.register_hook( - lambda: {(self.server_name,): self.updates.get_status()}, + LaterGauge( + name="synapse_background_update_status", + desc="Background update status", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): self.updates.get_status()}, ) self._previous_txn_total_time = 0.0 diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 22948f8c2..d68614055 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -110,6 +110,7 @@ def _load_rules( msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, msc4210_enabled=experimental_config.msc4210_enabled, + msc4306_enabled=experimental_config.msc4306_enabled, ) return filtered_rules diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 94a1274ed..654250fad 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -84,13 +84,6 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" _POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000 -federation_known_servers_gauge = LaterGauge( - name="synapse_federation_known_servers", - desc="", - labelnames=[SERVER_NAME_LABEL], -) - - @attr.s(frozen=True, slots=True, auto_attribs=True) class EventIdMembership: """Returned by `get_membership_from_event_ids`""" @@ -123,8 +116,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): 1, self._count_known_servers, ) - federation_known_servers_gauge.register_hook( - lambda: {(self.server_name,): self._known_servers_count} + LaterGauge( + 
name="synapse_federation_known_servers", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): self._known_servers_count}, ) @wrap_as_background_process("_count_known_servers") diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 4933224f0..24a99cf44 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -14,7 +14,7 @@ import logging from typing import ( TYPE_CHECKING, Any, - Dict, + FrozenSet, Iterable, List, Optional, @@ -33,6 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import EventOrderings from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -50,6 +51,14 @@ class ThreadSubscription: """ +class AutomaticSubscriptionConflicted: + """ + Marker return value to signal that an automatic subscription was skipped, + because it conflicted with an unsubscription that we consider to have + been made later than the event causing the automatic subscription. 
+ """ + + class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -91,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self.get_subscription_for_thread.invalidate( (row.user_id, row.room_id, row.event_id) ) + self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id)) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -101,75 +111,196 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self._thread_subscriptions_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) + @staticmethod + def _should_skip_autosubscription_after_unsubscription( + *, + autosub: EventOrderings, + unsubscribed_at: EventOrderings, + ) -> bool: + """ + Returns whether an automatic subscription occurring *after* an unsubscription + should be skipped, because the unsubscription already 'acknowledges' the event + causing the automatic subscription (the cause event). + + To determine *after*, we use `stream_ordering` unless the event is backfilled + (negative `stream_ordering`) and fallback to topological ordering. + + Args: + autosub: the stream_ordering and topological_ordering of the cause event + unsubscribed_at: + the maximum stream ordering and the maximum topological ordering at the time of unsubscription + + Returns: + True if the automatic subscription should be skipped + """ + # For normal rooms, these two orderings should be positive, because + # they don't refer to a specific event but rather the maximum at the + # time of unsubscription. + # + # However, for rooms that have never been joined and that are being peeked at, + # we might not have a single non-backfilled event and therefore the stream + # ordering might be negative, so we don't assert this case. 
+ assert unsubscribed_at.topological > 0 + + unsubscribed_at_backfilled = unsubscribed_at.stream < 0 + if ( + not unsubscribed_at_backfilled + and unsubscribed_at.stream >= autosub.stream > 0 + ): + # non-backfilled events: the unsubscription is later according to + # the stream + return True + + if autosub.stream < 0: + # the auto-subscription cause event was backfilled, so fall back to + # topological ordering + if unsubscribed_at.topological >= autosub.topological: + return True + + return False + async def subscribe_user_to_thread( - self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool - ) -> Optional[int]: + self, + user_id: str, + room_id: str, + thread_root_event_id: str, + *, + automatic_event_orderings: Optional[EventOrderings], + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: """Updates a user's subscription settings for a specific thread root. If no change would be made to the subscription, does not produce any database change. + Case-by-case: + - if we already have an automatic subscription: + - new automatic subscriptions will be no-ops (no database write), + - new manual subscriptions will overwrite the automatic subscription + - if we already have a manual subscription: + we don't update (no database write) in either case, because: + - the existing manual subscription wins over a new automatic subscription request + - there would be no need to write a manual subscription because we already have one + Args: user_id: The ID of the user whose settings are being updated. room_id: The ID of the room the thread root belongs to. thread_root_event_id: The event ID of the thread root. - automatic: Whether the subscription was performed automatically by the user's client. - Only `False` will overwrite an existing value of automatic for a subscription row. + automatic_event_orderings: + Value depends on whether the subscription was performed automatically by the user's client. + For manual subscriptions: None. 
+ For automatic subscriptions: the orderings of the event. Returns: - The stream ID for this update, if the update isn't no-opped. + If a subscription is made: (int) the stream ID for this update. + If a subscription already exists and did not need to be updated: None + If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted """ assert self._can_write_to_thread_subscriptions - def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]: - already_automatic = self.db_pool.simple_select_one_onecol_txn( - txn, - table="thread_subscriptions", - keyvalues={ - "user_id": user_id, - "event_id": thread_root_event_id, - "room_id": room_id, - "subscribed": True, - }, - retcol="automatic", - allow_none=True, - ) - - if already_automatic is None: - already_subscribed = False - already_automatic = True - else: - already_subscribed = True - # convert int (SQLite bool) to Python bool - already_automatic = bool(already_automatic) - - if already_subscribed and already_automatic == automatic: - # there is nothing we need to do here - return None - - stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) - - values: Dict[str, Optional[Union[bool, int, str]]] = { - "subscribed": True, - "stream_id": stream_id, - "instance_name": self._instance_name, - "automatic": already_automatic and automatic, - } - - self.db_pool.simple_upsert_txn( - txn, - table="thread_subscriptions", - keyvalues={ - "user_id": user_id, - "event_id": thread_root_event_id, - "room_id": room_id, - }, - values=values, - ) - + def _invalidate_subscription_caches(txn: LoggingTransaction) -> None: txn.call_after( self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) + + def _subscribe_user_to_thread_txn( + txn: LoggingTransaction, + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + requested_automatic = 
automatic_event_orderings is not None + + row = self.db_pool.simple_select_one_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + retcols=( + "subscribed", + "automatic", + "unsubscribed_at_stream_ordering", + "unsubscribed_at_topological_ordering", + ), + allow_none=True, + ) + + if row is None: + # We have never subscribed before, simply insert the row and finish + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_insert_txn( + txn, + table="thread_subscriptions", + values={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, + ) + _invalidate_subscription_caches(txn) + return stream_id + + # we already have either a subscription or a prior unsubscription here + ( + subscribed, + already_automatic, + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ) = row + + if subscribed and (not already_automatic or requested_automatic): + # we are already subscribed and the current subscription state + # is good enough (either we already have a manual subscription, + # or we requested an automatic subscription) + # In that case, nothing to change here. 
+ # (See docstring for case-by-case explanation) + return None + + if not subscribed and requested_automatic: + assert automatic_event_orderings is not None + # we previously unsubscribed and we are now automatically subscribing + # Check whether the new autosubscription should be skipped + if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription( + autosub=automatic_event_orderings, + unsubscribed_at=EventOrderings( + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ), + ): + # skip the subscription + return AutomaticSubscriptionConflicted() + + # At this point: we have now finished checking that we need to make + # a subscription, updating the current row. + + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_update_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + updatevalues={ + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, + ) + _invalidate_subscription_caches(txn) return stream_id @@ -214,6 +345,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + # Find the maximum stream ordering and topological ordering of the room, + # which we then store against this unsubscription so we can skip future + # automatic subscriptions that are caused by an event logically earlier + # than this unsubscription. + txn.execute( + """ + SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events + WHERE room_id = ? 
+ """, + (room_id,), + ) + ord_row = txn.fetchone() + assert ord_row is not None + max_stream_ordering, max_topological_ordering = ord_row + self.db_pool.simple_update_txn( txn, table="thread_subscriptions", @@ -227,6 +373,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): "subscribed": False, "stream_id": stream_id, "instance_name": self._instance_name, + "unsubscribed_at_stream_ordering": max_stream_ordering, + "unsubscribed_at_topological_ordering": max_topological_ordering, }, ) @@ -234,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) return stream_id @@ -246,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): Purge all subscriptions for the user. The fact that subscriptions have been purged will not be streamed; all stream rows for the user will in fact be removed. - This is intended only for dealing with user deactivation. + + This must only be used for user deactivation, + because it does not invalidate the `subscribers_to_thread` cache. """ def _purge_thread_subscription_settings_for_user_txn( @@ -307,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): return ThreadSubscription(automatic=automatic) + # max_entries=100 rationale: + # this returns a potentially large datastructure + # (since each entry contains a set which contains a potentially large number of user IDs), + # whereas the default of 10'000 entries for @cached feels more + # suitable for very small cache entries. 
+ # + # Overall, when bearing in mind the usual profile of a small community-server or company-server + # (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very + # unlikely we would benefit from keeping hot the subscribers for as many as 100 threads, + # since it's unlikely that so many threads will be active in a short span of time on a small homeserver. + # It feels that medium servers will probably also not exhaust this limit. + # Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor + # or carefully following the usage patterns & cache metrics. + # Finally, the query is not so intensive that computing it every time is a huge deal, but given people + # often send messages back-to-back in the same thread it seems like it would offer a mild benefit. + @cached(max_entries=100) + async def get_subscribers_to_thread( + self, room_id: str, thread_root_event_id: str + ) -> FrozenSet[str]: + """ + Returns: + the set of user_ids for local users who are subscribed to the given thread. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="thread_subscriptions", + keyvalues={ + "room_id": room_id, + "event_id": thread_root_event_id, + "subscribed": True, + }, + retcol="user_id", + desc="get_subscribers_to_thread", + ) + ) + def get_max_thread_subscriptions_stream_id(self) -> int: """Get the current maximum stream_id for thread subscriptions. @@ -316,7 +506,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): return self._thread_subscriptions_id_gen.get_current_token() async def get_updated_thread_subscriptions( - self, from_id: int, to_id: int, limit: int + self, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: """Get updates to thread subscriptions between two stream IDs. 
@@ -349,7 +539,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): ) async def get_updated_thread_subscriptions_for_user( - self, user_id: str, from_id: int, to_id: int, limit: int + self, user_id: str, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str]]: """Get updates to thread subscriptions for a specific user. diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql new file mode 100644 index 000000000..03b8a1a63 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql @@ -0,0 +1,20 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- The maximum stream_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_stream_ordering BIGINT; + +-- The maximum topological_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_topological_ordering BIGINT; diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres new file mode 100644 index 000000000..fc5d555db --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. 
+-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS + $$The maximum stream_ordering in the room when the unsubscription was made.$$; + +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS + $$The maximum topological_ordering in the room when the unsubscription was made.$$; diff --git a/synapse/synapse_rust/push.pyi b/synapse/synapse_rust/push.pyi index 3f317c328..a3e12ad64 100644 --- a/synapse/synapse_rust/push.pyi +++ b/synapse/synapse_rust/push.pyi @@ -49,6 +49,7 @@ class FilteredPushRules: msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -67,13 +68,19 @@ class PushRuleEvaluator: room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def run( self, push_rules: FilteredPushRules, user_id: Optional[str], display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool], ) -> Collection[Union[Mapping, str]]: ... def matches( - self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str] + self, + condition: JsonDict, + user_id: Optional[str], + display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool] = None, ) -> bool: ... 
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 3b516fce3..0ea3a0a4a 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from typing_extensions import Self from synapse.appservice.api import ApplicationService + from synapse.events import EventBase from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -1464,3 +1465,31 @@ class ScheduledTask: result: Optional[JsonMapping] # Optional error that should be assigned a value when the status is FAILED error: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class EventOrderings: + stream: int + """ + The stream_ordering of the event. + Negative numbers mean the event was backfilled. + """ + + topological: int + """ + The topological_ordering of the event. + Currently this is equivalent to the `depth` attributes of + the PDU. + """ + + @staticmethod + def from_event(event: "EventBase") -> "EventOrderings": + """ + Get the orderings from an event. + + Preconditions: + - the event must have been persisted (otherwise it won't have a stream ordering) + """ + stream = event.internal_metadata.stream_ordering + assert stream is not None + return EventOrderings(stream, event.depth) diff --git a/synapse/util/pydantic_models.py b/synapse/util/pydantic_models.py index ba9e7bb7d..488070950 100644 --- a/synapse/util/pydantic_models.py +++ b/synapse/util/pydantic_models.py @@ -13,7 +13,11 @@ # # -from synapse._pydantic_compat import BaseModel, Extra +import re +from typing import Any, Callable, Generator + +from synapse._pydantic_compat import BaseModel, Extra, StrictStr +from synapse.types import EventID class ParseModel(BaseModel): @@ -37,3 +41,43 @@ class ParseModel(BaseModel): extra = Extra.ignore # By default, don't allow fields to be reassigned after parsing. 
allow_mutation = False + + +class AnyEventId(StrictStr): + """ + A validator for strings that need to be an Event ID. + + Accepts any valid grammar of Event ID from any room version. + """ + + EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile( + r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$" + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]: + yield from super().__get_validators__() # type: ignore + yield cls.validate_event_id + + @classmethod + def validate_event_id(cls, value: str) -> str: + if not value.startswith("$"): + raise ValueError("Event ID must start with `$`") + + if ":" in value: + # Room versions 1 and 2 + EventID.from_string(value) # throws on fail + else: + # Room versions 3+: event ID is $ + a base64 sha256 hash + # Room version 3 is base64, 4+ are base64Url + # In both cases, the base64 is unpadded. + # refs: + # - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk + # - https://spec.matrix.org/v1.15/rooms/v4/ e.g. $Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg + b64_hash = value[1:] + if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None: + raise ValueError( + "Event ID must either have a domain part or be a valid hash" + ) + + return value diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index b3c65676c..f5e592d80 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -131,31 +131,27 @@ def _get_counts_from_rate_limiter_instance( # We track the number of affected hosts per time-period so we can # differentiate one really noisy homeserver from a general # ratelimit tuning problem across the federation. 
-sleep_affected_hosts_gauge = LaterGauge( +LaterGauge( name="synapse_rate_limit_sleep_affected_hosts", desc="Number of hosts that had requests put to sleep", labelnames=["rate_limiter_name", SERVER_NAME_LABEL], -) -sleep_affected_hosts_gauge.register_hook( - lambda: _get_counts_from_rate_limiter_instance( + caller=lambda: _get_counts_from_rate_limiter_instance( lambda rate_limiter_instance: sum( ratelimiter.should_sleep() for ratelimiter in rate_limiter_instance.ratelimiters.values() ) - ) + ), ) -reject_affected_hosts_gauge = LaterGauge( +LaterGauge( name="synapse_rate_limit_reject_affected_hosts", desc="Number of hosts that had requests rejected", labelnames=["rate_limiter_name", SERVER_NAME_LABEL], -) -reject_affected_hosts_gauge.register_hook( - lambda: _get_counts_from_rate_limiter_instance( + caller=lambda: _get_counts_from_rate_limiter_instance( lambda rate_limiter_instance: sum( ratelimiter.should_reject() for ratelimiter in rate_limiter_instance.ratelimiters.values() ) - ) + ), ) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 904f99fa4..fdcacdf12 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -44,13 +44,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -running_tasks_gauge = LaterGauge( - name="synapse_scheduler_running_tasks", - desc="The number of concurrent running tasks handled by the TaskScheduler", - labelnames=[SERVER_NAME_LABEL], -) - - class TaskScheduler: """ This is a simple task scheduler designed for resumable tasks. 
Normally, @@ -137,8 +130,11 @@ class TaskScheduler: TaskScheduler.SCHEDULE_INTERVAL_MS, ) - running_tasks_gauge.register_hook( - lambda: {(self.server_name,): len(self._running_tasks)} + LaterGauge( + name="synapse_scheduler_running_tasks", + desc="The number of concurrent running tasks handled by the TaskScheduler", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self._running_tasks)}, ) def register_action( diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 5a3c3c1c4..61874564a 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -22,13 +22,7 @@ from typing import Dict, Protocol, Tuple from prometheus_client.core import Sample -from synapse.metrics import ( - REGISTRY, - SERVER_NAME_LABEL, - InFlightGauge, - LaterGauge, - generate_latest, -) +from synapse.metrics import REGISTRY, InFlightGauge, generate_latest from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest @@ -291,42 +285,6 @@ class CacheMetricsTests(unittest.HomeserverTestCase): self.assertEqual(hs2_cache_max_size_metric_value, "777.0") -class LaterGaugeTests(unittest.HomeserverTestCase): - def test_later_gauge_multiple_servers(self) -> None: - """ - Test that LaterGauge metrics are reported correctly across multiple servers. We - will have an metrics entry for each homeserver that is labeled with the - `server_name` label. 
- """ - later_gauge = LaterGauge( - name="foo", - desc="", - labelnames=[SERVER_NAME_LABEL], - ) - later_gauge.register_hook(lambda: {("hs1",): 1}) - later_gauge.register_hook(lambda: {("hs2",): 2}) - - metrics_map = get_latest_metrics() - - # Find the metrics for the caches from both homeservers - hs1_metric = 'foo{server_name="hs1"}' - hs1_metric_value = metrics_map.get(hs1_metric) - self.assertIsNotNone( - hs1_metric_value, - f"Missing metric {hs1_metric} in cache metrics {metrics_map}", - ) - hs2_metric = 'foo{server_name="hs2"}' - hs2_metric_value = metrics_map.get(hs2_metric) - self.assertIsNotNone( - hs2_metric_value, - f"Missing metric {hs2_metric} in cache metrics {metrics_map}", - ) - - # Sanity check the metric values - self.assertEqual(hs1_metric_value, "1.0") - self.assertEqual(hs2_metric_value, "2.0") - - def get_latest_metrics() -> Dict[str, str]: """ Collect the latest metrics from the registry and parse them into an easy to use map. diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 6c8c3a09d..fad5c7aff 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from parameterized import parameterized from twisted.internet.testing import MemoryReactor -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator._action_for_event_by_user.assert_not_called() def _create_and_process( - self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None + self, + bulk_evaluator: BulkPushRuleEvaluator, + content: Optional[JsonDict] = None, + type: str = "test", ) -> bool: """Returns true iff the 
`mentions` trigger an event push action.""" # Create a new message event which should cause a notification. @@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.event_creation_handler.create_event( self.requester, { - "type": "test", + "type": type, "room_id": self.room_id, "content": content or {}, "sender": f"@bob:{self.hs.hostname}", @@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + + @override_config({"experimental_features": {"msc4306_enabled": True}}) + def test_thread_subscriptions(self) -> None: + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type=EventTypes.Message, + ) + ) + + self.get_success( + self.hs.get_datastores().main.subscribe_user_to_thread( + self.alice, + self.room_id, + thread_root_id, + automatic_event_orderings=None, + ) + ) + + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message after subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) + + def test_with_disabled_thread_subscriptions(self) -> None: + """ + Test what happens with threaded events when MSC4306 is disabled. + + FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed. + """ + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + # When MSC4306 is not enabled, a threaded message generates a notification + # by default. 
+ self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8789c2f4c..3a351acff 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, related_events: Optional[JsonDict] = None, msc4210: bool = False, + msc4306: bool = False, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc4210_enabled=msc4210, + msc4306_enabled=msc4306, ) def test_display_name(self) -> None: @@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) ) + def test_thread_subscription_subscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in a subscribed thread. + """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + + def test_thread_subscription_unsubscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in an unsubscribed thread. 
+ """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + + def test_thread_subscription_unthreaded(self) -> None: + """ + Test MSC4306 thread subscription push rules against an unthreaded event. + """ + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Squawk"}, msc4306=True + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" diff --git a/tests/replication/tcp/streams/test_thread_subscriptions.py b/tests/replication/tcp/streams/test_thread_subscriptions.py index 035f06187..7283aa851 100644 --- a/tests/replication/tcp/streams/test_thread_subscriptions.py +++ b/tests/replication/tcp/streams/test_thread_subscriptions.py @@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): "@test_user:example.org", room_id, thread_root_id, - automatic=True, + automatic_event_orderings=None, ) ) updates.append(thread_root_id) @@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): "@test_user:example.org", other_room_id, other_thread_root_id, - automatic=False, + 
automatic_event_orderings=None, ) ) @@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): for user_id in users: self.get_success( store.subscribe_user_to_thread( - user_id, room_id, thread_root_id, automatic=True + user_id, room_id, thread_root_id, automatic_event_orderings=None ) ) diff --git a/tests/rest/client/test_thread_subscriptions.py b/tests/rest/client/test_thread_subscriptions.py index 624cb9c72..3fbf3c5bf 100644 --- a/tests/rest/client/test_thread_subscriptions.py +++ b/tests/rest/client/test_thread_subscriptions.py @@ -15,6 +15,7 @@ from http import HTTPStatus from twisted.internet.testing import MemoryReactor +from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room, thread_subscriptions from synapse.server import HomeServer @@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Create a room and send a message to use as a thread root self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) self.helper.join(self.room_id, self.other_user_id, tok=self.other_token) - response = self.helper.send(self.room_id, body="Root message", tok=self.token) - self.root_event_id = response["event_id"] + (self.root_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) # Send a message in the thread - self.helper.send_event( - room_id=self.room_id, - type="m.room.message", - content={ - "body": "Thread message", + self.threaded_events = self.helper.send_messages( + self.room_id, + 2, + content_fn=lambda idx: { + "body": f"Thread message {idx}", "msgtype": "m.text", "m.relates_to": { "rel_type": "m.thread", @@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - { - "automatic": False, - }, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -127,7 
+127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {"automatic": self.threaded_events[0]}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) # Assert the subscription was saved channel = self.make_request( @@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": False}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) @@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.json_body, {"automatic": True}) - # Now also register a manual subscription channel = self.make_request( "DELETE", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - # Assert the manual subscription was not overridden channel = self.make_request( "GET", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -224,7 +222,7 @@ class 
ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription", - {"automatic": True}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {}, access_token=no_access_token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - # non-boolean `automatic` - {"automatic": "true"}, + # non-Event ID `automatic` + {"automatic": True}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + # non-Event ID `automatic` + {"automatic": "$malformedEventId"}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + def test_auto_subscribe_cause_event_not_in_thread(self) -> None: + """ + Test making an automatic subscription, where the cause event is not + actually in the thread. + This is an error. + """ + (unrelated_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": unrelated_event_id}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body) + self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD) + + def test_auto_resubscription_conflict(self) -> None: + """ + Test that an automatic subscription that conflicts with an unsubscription + is skipped. 
+ """ + # Reuse the test that subscribes and unsubscribes + self.test_unsubscribe() + + # Now no matter which event we present as the cause of an automatic subscription, + # the automatic subscription is skipped. + # This is because the unsubscription happened after all of the events. + for event in self.threaded_events: + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": event, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body) + self.assertEqual( + channel.json_body["errcode"], + Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + channel.text_body, + ) + + # Check the subscription was not made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + + # But if a new event is sent after the unsubscription took place, + # that one can be used for an automatic subscription + (later_event_id,) = self.helper.send_messages( + self.room_id, + 1, + content_fn=lambda _: { + "body": "Thread message after unsubscription", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": self.root_event_id, + }, + }, + tok=self.token, + ) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": later_event_id, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) + + # Check the subscription was made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": True}) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 02566709b..bb214759d 100644 --- 
a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -29,12 +29,14 @@ from http import HTTPStatus from typing import ( Any, AnyStr, + Callable, Dict, Iterable, Literal, Mapping, MutableMapping, Optional, + Sequence, Tuple, overload, ) @@ -45,7 +47,7 @@ import attr from twisted.internet.testing import MemoryReactorClock from twisted.web.server import Site -from synapse.api.constants import Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -394,6 +396,32 @@ class RestHelper: custom_headers=custom_headers, ) + def send_messages( + self, + room_id: str, + num_events: int, + content_fn: Callable[[int], JsonDict] = lambda idx: { + "msgtype": "m.text", + "body": f"Test event {idx}", + }, + tok: Optional[str] = None, + ) -> Sequence[str]: + """ + Helper to send a handful of sequential events and return their event IDs as a sequence. + """ + event_ids = [] + + for event_index in range(num_events): + response = self.send_event( + room_id, + EventTypes.Message, + content_fn(event_index), + tok=tok, + ) + event_ids.append(response["event_id"]) + + return event_ids + def send_event( self, room_id: str, diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 69317d5b0..2a5c440cf 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -12,13 +12,18 @@ # . 
# -from typing import Optional +from typing import Optional, Union from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscriptionsWorkerStore, +) from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.types import EventOrderings from synapse.util import Clock from tests import unittest @@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self, thread_root_id: str, *, - automatic: bool, + automatic_event_orderings: Optional[EventOrderings], room_id: Optional[str] = None, user_id: Optional[str] = None, - ) -> Optional[int]: + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: if user_id is None: user_id = self.user_id @@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): user_id, room_id, thread_root_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) ) @@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Subscribe self._subscribe( self.thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), ) # Assert subscription went through @@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Now make it a manual subscription self._subscribe( self.thread_root_id, - automatic=False, + automatic_event_orderings=None, ) # Assert the manual subscription overrode the automatic one @@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): def test_purge_thread_subscriptions_for_user(self) -> None: """Test purging all thread subscription settings for a user.""" # Set subscription settings for multiple threads - self._subscribe(self.thread_root_id, automatic=True) - self._subscribe(self.other_thread_root_id, automatic=False) + self._subscribe( + 
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + self._subscribe(self.other_thread_root_id, automatic_event_orderings=None) subscriptions = self.get_success( self.store.get_updated_thread_subscriptions_for_user( @@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): def test_get_updated_thread_subscriptions(self) -> None: """Test getting updated thread subscriptions since a stream ID.""" - stream_id1 = self._subscribe(self.thread_root_id, automatic=False) - stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True) - assert stream_id1 is not None - assert stream_id2 is not None + stream_id1 = self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + stream_id2 = self._subscribe( + self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, AutomaticSubscriptionConflicted + ) + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates since initial ID (should include both changes) updates = self.get_success( - self.store.get_updated_thread_subscriptions(0, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=0, to_id=stream_id2, limit=10 + ) ) self.assertEqual(len(updates), 2) # Get updates since first change (should include only the second change) updates = self.get_success( - self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=stream_id1, to_id=stream_id2, limit=10 + ) ) self.assertEqual( updates, @@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): other_user_id = "@other_user:test" # Set thread subscription for main user - stream_id1 = self._subscribe(self.thread_root_id, automatic=True) - assert stream_id1 is not None + stream_id1 = self._subscribe( + self.thread_root_id, 
automatic_event_orderings=EventOrderings(1, 1) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, AutomaticSubscriptionConflicted + ) # Set thread subscription for other user stream_id2 = self._subscribe( self.other_thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), user_id=other_user_id, ) - assert stream_id2 is not None + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates for main user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - self.user_id, 0, stream_id2, 10 + self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) @@ -264,9 +289,80 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Get updates for other user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - other_user_id, 0, max(stream_id1, stream_id2), 10 + other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10 ) ) self.assertEqual( updates, [(stream_id2, self.room_id, self.other_thread_root_id)] ) + + def test_should_skip_autosubscription_after_unsubscription(self) -> None: + """ + Tests the comparison logic for whether an autosubscription should be skipped + due to a chronologically earlier but logically later unsubscription. 
+ """ + + func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription + + # Order of arguments: + # automatic cause event: stream order, then topological order + # unsubscribe maximums: stream order, then tological order + + # both orderings agree that the unsub is after the cause event + self.assertTrue( + func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2)) + ) + + # topological ordering is inconsistent with stream ordering, + # in that case favour stream ordering because it's what /sync uses + self.assertTrue( + func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1)) + ) + + # the automatic subscription is caused by a backfilled event here + # unfortunately we must fall back to topological ordering here + self.assertTrue( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3)) + ) + self.assertFalse( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1)) + ) + + def test_get_subscribers_to_thread(self) -> None: + """ + Test getting all subscribers to a thread at once. + + To check cache invalidations are correct, we do multiple + step-by-step rounds of subscription changes and assertions. 
+ """ + other_user_id = "@other_user:test" + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset()) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id,))) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id, other_user_id))) + + self._unsubscribe(self.thread_root_id, user_id=self.user_id) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((other_user_id,)))