diff --git a/app/src/main/kotlin/com/wire/android/ui/WireActivity.kt b/app/src/main/kotlin/com/wire/android/ui/WireActivity.kt index 85bba9d230..34b126c2a8 100644 --- a/app/src/main/kotlin/com/wire/android/ui/WireActivity.kt +++ b/app/src/main/kotlin/com/wire/android/ui/WireActivity.kt @@ -137,6 +137,7 @@ import com.wire.android.util.debug.LocalFeatureVisibilityFlags import com.wire.android.util.launchUpdateTheApp import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.delay import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.receiveAsFlow @@ -259,13 +260,21 @@ class WireActivity : BaseActivity() { val snackbarHostState = remember { SnackbarHostState() } val currentUserId = viewModel.globalAppState.currentUserId val appGraph = LocalContext.current.wireApplicationGraph - val sessionViewModelGraph = remember(appGraph, currentUserId) { + val currentSessionViewModelGraph = remember(appGraph, currentUserId) { currentUserId?.let { appGraph.sessionViewModelGraph } } + var retainedSessionViewModelGraph by remember(appGraph) { + mutableStateOf(null) + } + LaunchedEffect(currentSessionViewModelGraph) { + currentSessionViewModelGraph?.let { + retainedSessionViewModelGraph = it + } + } val authenticationViewModelGraph = remember(appGraph) { appGraph.authenticationViewModelGraph } - val activityViewModels = sessionViewModelGraph?.let { + val activityViewModels = currentSessionViewModelGraph?.let { wireActivityScopedViewModels(it) } @@ -305,6 +314,14 @@ class WireActivity : BaseActivity() { ?.destination ?.route ?.getBaseRoute() + val effectiveBaseRoute = currentBaseRoute ?: startDestination.baseRoute + LaunchedEffect(currentUserId, effectiveBaseRoute) { + if (currentUserId == null && effectiveBaseRoute in authenticationGraphRoutes) { + delay(SESSION_GRAPH_RELEASE_DELAY_MILLIS) + retainedSessionViewModelGraph = null + } + } + val sessionViewModelGraph = currentSessionViewModelGraph ?: retainedSessionViewModelGraph val metroViewModelGraph = rememberMetroViewModelGraph( currentBaseRoute = currentBaseRoute, startDestinationBaseRoute = startDestination.baseRoute, @@ -879,6 +896,7 @@ class WireActivity : BaseActivity() { companion object { private const val HANDLED_DEEPLINK_FLAG = "deeplink_handled_flag_key" private const val ORIGINAL_SAVED_INTENT_FLAG = "original_saved_intent" + private const val SESSION_GRAPH_RELEASE_DELAY_MILLIS = 500L private const val TAG = "WireActivity" } } diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/AuthenticationViewModelGraph.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/AuthenticationViewModelGraph.kt index 50bf923484..f98db62193 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/AuthenticationViewModelGraph.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/AuthenticationViewModelGraph.kt @@ -60,6 +60,7 @@ inline fun authenticationViewModel( metroViewModel( viewModelStoreOwner = viewModelStoreOwner, key = key, + scopeKeyOverride = AUTHENTICATION_VIEW_MODEL_SCOPE_KEY, ) { authenticationViewModelFactory.create() } @@ -75,10 +76,13 @@ inline fun authenticationSavedStateViewModel( metroSavedStateViewModel( viewModelStoreOwner = viewModelStoreOwner, key = key, + scopeKeyOverride = AUTHENTICATION_VIEW_MODEL_SCOPE_KEY, ) { savedStateHandle -> authenticationViewModelFactory.create(savedStateHandle) } +const val AUTHENTICATION_VIEW_MODEL_SCOPE_KEY = "authentication" + @Composable fun welcomeViewModel(): WelcomeViewModel = authenticationSavedStateViewModel { welcomeViewModel(it) } diff --git a/core/di/src/main/kotlin/com/wire/android/di/metro/MetroViewModelGraph.kt b/core/di/src/main/kotlin/com/wire/android/di/metro/MetroViewModelGraph.kt index 32794ff534..22fc5286d7 100644 --- a/core/di/src/main/kotlin/com/wire/android/di/metro/MetroViewModelGraph.kt +++ b/core/di/src/main/kotlin/com/wire/android/di/metro/MetroViewModelGraph.kt @@ -51,6 +51,7 @@ inline fun metroViewModel( "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner" }, key: String? = null, + scopeKeyOverride: String? = null, crossinline create: Graph.() -> VM, ): VM where Graph : MetroViewModelGraph, VM : ViewModel { val graph = checkNotNull(LocalMetroViewModelGraph.current as? Graph) { @@ -59,7 +60,7 @@ inline fun metroViewModel( val scopedKey = scopedMetroViewModelKey( defaultKey = VM::class.qualifiedName, key = key, - scopeKey = graph.viewModelScopeKey, + scopeKey = scopeKeyOverride ?: graph.viewModelScopeKey, ) val factory = remember(graph) { viewModelFactory { @@ -82,6 +83,7 @@ inline fun metroSavedStateViewModel( "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner" }, key: String? = null, + scopeKeyOverride: String? = null, crossinline create: Graph.(SavedStateHandle) -> VM, ): VM where Graph : MetroViewModelGraph, VM : ViewModel { val graph = checkNotNull(LocalMetroViewModelGraph.current as? Graph) { @@ -90,7 +92,7 @@ inline fun metroSavedStateViewModel( val scopedKey = scopedMetroViewModelKey( defaultKey = VM::class.qualifiedName, key = key, - scopeKey = graph.viewModelScopeKey, + scopeKey = scopeKeyOverride ?: graph.viewModelScopeKey, ) val factory = remember(graph) { viewModelFactory {