From 5d35f02abb4664729a149cd7ba2715635bb4aed0 Mon Sep 17 00:00:00 2001 From: Valere Date: Tue, 7 Dec 2021 19:56:14 +0100 Subject: [PATCH] Support using unpublished fallback key instead of generating And forgetFallback after 5mn --- .../sdk/internal/crypto/MXOlmDevice.kt | 14 ++ .../internal/crypto/OneTimeKeysUploader.kt | 123 ++++++++++++------ 2 files changed, 97 insertions(+), 40 deletions(-) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt index 6479a8ddce..50f3e6acd0 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt @@ -136,6 +136,11 @@ internal class MXOlmDevice @Inject constructor( return store.getOlmAccount().maxOneTimeKeys() } + /** + * Returns an unpublished fallback key + * A call to markKeysAsPublished will mark it as published and this + * call will return null (until a call to generateFallbackKey is made) + */ fun getFallbackKey(): MutableMap>? { try { return store.getOlmAccount().fallbackKey() @@ -154,6 +159,15 @@ internal class MXOlmDevice @Inject constructor( } } + fun forgetFallbackKey() { + try { + store.getOlmAccount().forgetFallbackKey() + store.saveOlmAccount() + } catch (e: Exception) { + Timber.e("## forgetFallbackKey() : failed") + } + } + /** * Release the instance */ diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OneTimeKeysUploader.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OneTimeKeysUploader.kt index 7759e04c7c..9366ecbd6d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OneTimeKeysUploader.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OneTimeKeysUploader.kt @@ -16,6 +16,7 @@ package org.matrix.android.sdk.internal.crypto +import android.content.Context import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.internal.crypto.model.MXKey import org.matrix.android.sdk.internal.crypto.model.rest.KeysUploadResponse @@ -28,11 +29,14 @@ import javax.inject.Inject import kotlin.math.floor import kotlin.math.min +const val FIVE_MINUTES = 5 * 60_000L + @SessionScope internal class OneTimeKeysUploader @Inject constructor( private val olmDevice: MXOlmDevice, private val objectSigner: ObjectSigner, - private val uploadKeysTask: UploadKeysTask + private val uploadKeysTask: UploadKeysTask, + context: Context ) { // tell if there is a OTK check in progress private var oneTimeKeyCheckInProgress = false @@ -42,6 +46,9 @@ internal class OneTimeKeysUploader @Inject constructor( private var oneTimeKeyCount: Int? = null private var needNewFallbackKey: Boolean = false + // Simple storage to remember when was uploaded the last fallback key + private val storage = context.getSharedPreferences("OneTimeKeysUploader_${olmDevice.deviceEd25519Key.hashCode()}", Context.MODE_PRIVATE) + /** * Stores the current one_time_key count which will be handled later (in a call of * _onSyncCompleted). The count is e.g. coming from a /sync response. @@ -70,9 +77,19 @@ internal class OneTimeKeysUploader @Inject constructor( return } - lastOneTimeKeyCheck = System.currentTimeMillis() oneTimeKeyCheckInProgress = true + val oneTimeKeyCountFromSync = oneTimeKeyCount + ?: fetchOtkCount() // we don't have count from sync so get from server + ?: return Unit.also { + oneTimeKeyCheckInProgress = false + Timber.w("maybeUploadOneTimeKeys: Failed to get otk count from server") + } + + Timber.d("maybeUploadOneTimeKeys: otk count $oneTimeKeyCountFromSync , needs fallback key $needNewFallbackKey") + + lastOneTimeKeyCheck = System.currentTimeMillis() + // We then check how many keys we can store in the Account object. val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys() @@ -83,32 +100,32 @@ internal class OneTimeKeysUploader @Inject constructor( // discard the oldest private keys first. This will eventually clean // out stale private keys that won't receive a message. val keyLimit = floor(maxOneTimeKeys / 2.0).toInt() - if (oneTimeKeyCount == null) { - // Ask the server how many otk he has - oneTimeKeyCount = fetchOtkCount() - } - val oneTimeKeyCountFromSync = oneTimeKeyCount - if (oneTimeKeyCountFromSync != null) { - // We need to keep a pool of one time public keys on the server so that - // other devices can start conversations with us. But we can only store - // a finite number of private keys in the olm Account object. - // To complicate things further then can be a delay between a device - // claiming a public one time key from the server and it sending us a - // message. We need to keep the corresponding private key locally until - // we receive the message. - // But that message might never arrive leaving us stuck with duff - // private keys clogging up our local storage. - // So we need some kind of engineering compromise to balance all of - // these factors. - tryOrNull("Unable to upload OTK") { - val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit) - Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent") - } - } else { - Timber.w("maybeUploadOneTimeKeys: waiting to know the number of OTK from the sync") - lastOneTimeKeyCheck = 0 + + // We need to keep a pool of one time public keys on the server so that + // other devices can start conversations with us. But we can only store + // a finite number of private keys in the olm Account object. + // To complicate things further then can be a delay between a device + // claiming a public one time key from the server and it sending us a + // message. We need to keep the corresponding private key locally until + // we receive the message. + // But that message might never arrive leaving us stuck with duff + // private keys clogging up our local storage. + // So we need some kind of engineering compromise to balance all of + // these factors. + tryOrNull("Unable to upload OTK") { + val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit) + Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent") } oneTimeKeyCheckInProgress = false + + // Check if we need to forget a fallback key + val latestPublishedTime = getLastFallbackKeyPublishTime() + if (latestPublishedTime != 0L && System.currentTimeMillis() - latestPublishedTime > FIVE_MINUTES) { + // This should be called once you are reasonably certain that you will not receive any more messages + // that use the old fallback key (e.g. 5 minutes after the new fallback key has been published) + Timber.d("## forgetFallbackKey()") + olmDevice.forgetFallbackKey() + } } private suspend fun fetchOtkCount(): Int? { @@ -138,24 +155,52 @@ internal class OneTimeKeysUploader @Inject constructor( keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER) olmDevice.generateOneTimeKeys(keysThisLoop) } - if (needNewFallbackKey) { - Timber.d("## CRYPTO: New fallback key needed") + if (needNewFallbackKey && !hasUnpublishedFallbackKey()) { + // if there is already fallback key, but that hasn't been published yet, we + // can use that instead of generating a new one olmDevice.generateFallbackKey() + Timber.d("maybeUploadOneTimeKeys: Fallback key generated") + // As we generated a new one, it's already forgetting one + // so we can clear the last publish time + // (in case the network calls fails after to avoid calling forgetKey) + saveLastFallbackKeyPublishTime(0L) } + // not copy paste error we check before sending if there is + // an unpublished key in order to saveLastFallbackKeyPublishTime if needed + val hadUnpublishedFallbackKey = hasUnpublishedFallbackKey() val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys()) olmDevice.markKeysAsPublished() + if (hadUnpublishedFallbackKey) { + // It had an unpublished fallback key that was published just now + saveLastFallbackKeyPublishTime(System.currentTimeMillis()) + } + needNewFallbackKey = false if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) { // Maybe upload other keys - return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit) + return keysThisLoop + + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit) + + (if (hadUnpublishedFallbackKey) 1 else 0) } else { Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519") throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519") } } + private fun hasUnpublishedFallbackKey(): Boolean { + return olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty().isNotEmpty() + } + + private fun saveLastFallbackKeyPublishTime(timeMillis: Long) { + storage.edit().putLong("last_fb_key_publish", timeMillis).apply() + } + + private fun getLastFallbackKeyPublishTime(): Long { + return storage.getLong("last_fb_key_publish", 0) + } + /** * Upload curve25519 one time keys. */ @@ -177,17 +222,15 @@ internal class OneTimeKeysUploader @Inject constructor( } val fallbackJson = mutableMapOf() - if (needNewFallbackKey) { - val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty() - fallbackCurve25519Map.forEach { (key_id, key) -> - val k = mutableMapOf() - k["key"] = key - k["fallback"] = true - val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k) - k["signatures"] = objectSigner.signObject(canonicalJson) + val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty() + fallbackCurve25519Map.forEach { (key_id, key) -> + val k = mutableMapOf() + k["key"] = key + k["fallback"] = true + val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k) + k["signatures"] = objectSigner.signObject(canonicalJson) - fallbackJson["signed_curve25519:$key_id"] = k - } + fallbackJson["signed_curve25519:$key_id"] = k } // For now, we set the device id explicitly, as we may not be using the @@ -208,6 +251,6 @@ internal class OneTimeKeysUploader @Inject constructor( private const val ONE_TIME_KEY_GENERATION_MAX_NUMBER = 5 // frequency with which to check & upload one-time keys - private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60 * 1000).toLong() // one minute + private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60_000).toLong() // one minute } }