diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt index eb006810bf4..ac296566760 100644 --- a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt @@ -1,55 +1,143 @@ package dev.inmo.micro_utils.coroutines +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract class SmartKeyRWLocker( + globalLockerReadPermits: Int = Int.MAX_VALUE, + globalLockerWriteIsLocked: Boolean = false, private val perKeyReadPermits: Int = Int.MAX_VALUE ) { - private val internalRWLocker = SmartRWLocker() + private val globalRWLocker: SmartRWLocker = SmartRWLocker( + readPermits = globalLockerReadPermits, + writeIsLocked = globalLockerWriteIsLocked + ) private val lockers = mutableMapOf() + private val lockersMutex = Mutex() + private val lockersWritingLocker = SmartSemaphore.Mutable(Int.MAX_VALUE) + private val globalWritingLocker = SmartSemaphore.Mutable(Int.MAX_VALUE) private fun allocateLockerWithoutLock(key: T) = lockers.getOrPut(key) { SmartRWLocker(perKeyReadPermits) } + private suspend fun allocateLocker(key: T) = lockersMutex.withLock { + lockers.getOrPut(key) { + SmartRWLocker(perKeyReadPermits) + } + } - suspend fun writeMutex(key: T): SmartMutex.Immutable = internalRWLocker.withReadAcquire { + suspend fun writeMutex(key: T): SmartMutex.Immutable = globalRWLocker.withReadAcquire { allocateLockerWithoutLock(key).writeMutex } - suspend fun readSemaphore(key: T): SmartSemaphore.Immutable = internalRWLocker.withReadAcquire { + suspend fun readSemaphore(key: T): SmartSemaphore.Immutable = globalRWLocker.withReadAcquire { allocateLockerWithoutLock(key).readSemaphore } fun writeMutexOrNull(key: T): SmartMutex.Immutable? = lockers[key] ?.writeMutex fun readSemaphoreOrNull(key: T): SmartSemaphore.Immutable? = lockers[key] ?.readSemaphore + fun writeMutex(): SmartMutex.Immutable = globalRWLocker.writeMutex + fun readSemaphore(): SmartSemaphore.Immutable = globalRWLocker.readSemaphore + + suspend fun acquireRead() { + globalWritingLocker.acquire() + try { + lockersWritingLocker.waitReleaseAll() + globalRWLocker.acquireRead() + } catch (e: CancellationException) { + globalWritingLocker.release() + throw e + } + } + suspend fun releaseRead(): Boolean { + globalWritingLocker.release() + return globalRWLocker.releaseRead() + } + + suspend fun lockWrite() { + globalRWLocker.lockWrite() + } + suspend fun unlockWrite(): Boolean { + return globalRWLocker.unlockWrite() + } + fun isWriteLocked(): Boolean = globalRWLocker.writeMutex.isLocked == true + + suspend fun acquireRead(key: T) { - internalRWLocker.withReadAcquire { - val locker = allocateLockerWithoutLock(key) + globalRWLocker.acquireRead() + val locker = allocateLocker(key) + try { locker.acquireRead() + } catch (e: CancellationException) { + globalRWLocker.releaseRead() + throw e } } suspend fun releaseRead(key: T): Boolean { - return internalRWLocker.withReadAcquire { - lockers[key] - } ?.releaseRead() == true + val locker = allocateLocker(key) + return locker.releaseRead() && globalRWLocker.releaseRead() } suspend fun lockWrite(key: T) { - internalRWLocker.withWriteLock { - val locker = allocateLockerWithoutLock(key) - locker.lockWrite() + globalWritingLocker.withAcquire(globalWritingLocker.maxPermits) { + lockersWritingLocker.acquire() + } + try { + globalRWLocker.acquireRead() + try { + val locker = allocateLocker(key) + locker.lockWrite() + } catch (e: CancellationException) { + globalRWLocker.releaseRead() + throw e + } + } catch (e: CancellationException) { + lockersWritingLocker.release() + throw e } } suspend fun unlockWrite(key: T): Boolean { - return internalRWLocker.withWriteLock { - lockers[key] - } ?.unlockWrite() == true + val locker = allocateLocker(key) + return (locker.unlockWrite() && globalRWLocker.releaseRead()).also { + if (it) { + lockersWritingLocker.release() + } + } } fun isWriteLocked(key: T): Boolean = lockers[key] ?.writeMutex ?.isLocked == true } +@OptIn(ExperimentalContracts::class) +suspend inline fun SmartKeyRWLocker.withReadAcquire(action: () -> R): R { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + acquireRead() + try { + return action() + } finally { + releaseRead() + } +} + +@OptIn(ExperimentalContracts::class) +suspend inline fun SmartKeyRWLocker.withWriteLock(action: () -> R): R { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + lockWrite() + try { + return action() + } finally { + unlockWrite() + } +} + @OptIn(ExperimentalContracts::class) suspend inline fun SmartKeyRWLocker.withReadAcquire(key: T, action: () -> R): R { contract { diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt index 08069672cd6..60391978d81 100644 --- a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt @@ -1,5 +1,6 @@ package dev.inmo.micro_utils.coroutines +import kotlinx.coroutines.CancellationException import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -39,7 +40,12 @@ class SmartRWLocker(private val readPermits: Int = Int.MAX_VALUE, writeIsLocked: */ suspend fun lockWrite() { _writeMutex.lock() - _readSemaphore.acquire(readPermits) + try { + _readSemaphore.acquire(readPermits) + } catch (e: CancellationException) { + _writeMutex.unlock() + throw e + } } /** diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt index 29016f00a11..cef80cfab16 100644 --- a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt @@ -24,6 +24,7 @@ import kotlin.contracts.contract * [Mutable] creator */ sealed interface SmartSemaphore { + val maxPermits: Int val permitsStateFlow: StateFlow /** @@ -36,7 +37,7 @@ sealed interface SmartSemaphore { /** * Immutable variant of [SmartSemaphore]. In fact will depend on the owner of [permitsStateFlow] */ - class Immutable(override val permitsStateFlow: StateFlow) : SmartSemaphore + class Immutable(override val permitsStateFlow: StateFlow, override val maxPermits: Int) : SmartSemaphore /** * Mutable variant of [SmartSemaphore]. With that variant you may [lock] and [unlock]. Besides, you may create @@ -44,15 +45,16 @@ sealed interface SmartSemaphore { * * @param locked Preset state of [freePermits] and its internal [_freePermitsStateFlow] */ - class Mutable(private val permits: Int, acquiredPermits: Int = 0) : SmartSemaphore { + class Mutable(permits: Int, acquiredPermits: Int = 0) : SmartSemaphore { + override val maxPermits: Int = permits private val _freePermitsStateFlow = SpecialMutableStateFlow(permits - acquiredPermits) override val permitsStateFlow: StateFlow = _freePermitsStateFlow.asStateFlow() private val internalChangesMutex = Mutex(false) - fun immutable() = Immutable(permitsStateFlow) + fun immutable() = Immutable(permitsStateFlow, maxPermits) - private fun checkedPermits(permits: Int) = permits.coerceIn(1 .. this.permits) + private fun checkedPermits(permits: Int) = permits.coerceIn(1 .. this.maxPermits) /** * Holds call until this [SmartSemaphore] will be re-locked. That means that current method will @@ -126,10 +128,10 @@ sealed interface SmartSemaphore { */ suspend fun release(permits: Int = 1): Boolean { val checkedPermits = checkedPermits(permits) - return if (_freePermitsStateFlow.value < this.permits) { + return if (_freePermitsStateFlow.value < this.maxPermits) { internalChangesMutex.withLock { - if (_freePermitsStateFlow.value < this.permits) { - _freePermitsStateFlow.value = minOf(_freePermitsStateFlow.value + checkedPermits, this.permits) + if (_freePermitsStateFlow.value < this.maxPermits) { + _freePermitsStateFlow.value = minOf(_freePermitsStateFlow.value + checkedPermits, this.maxPermits) true } else { false @@ -166,3 +168,4 @@ suspend inline fun SmartSemaphore.Mutable.withAcquire(permits: Int = 1, acti * the fact that some other parties may lock it again */ suspend fun SmartSemaphore.waitRelease(permits: Int = 1) = permitsStateFlow.first { it >= permits } +suspend fun SmartSemaphore.waitReleaseAll() = permitsStateFlow.first { it == maxPermits } diff --git a/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt b/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt index d4c8981a87b..5fcfb7850f7 100644 --- a/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt +++ b/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt @@ -63,13 +63,54 @@ class SmartKeyRWLockerTests { assertFails { realWithTimeout(13.milliseconds) { locker.lockWrite() } } + val readPermitsBeforeLock = locker.readSemaphore().freePermits realWithTimeout(1.seconds) { locker.acquireRead() } locker.releaseRead() - assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } + assertEquals(readPermitsBeforeLock, locker.readSemaphore().freePermits) locker.releaseRead(it) } + assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } + realWithTimeout(1.seconds) { locker.lockWrite() } + assertFails { + realWithTimeout(13.milliseconds) { locker.acquireRead() } + } + assertTrue { locker.unlockWrite() } + assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } + } + @Test + fun writesBlockingGlobalWrite() = runTest { + val locker = SmartKeyRWLocker() + + val testKeys = (0 until 100).map { "test$it" } + + for (i in testKeys.indices) { + val it = testKeys[i] + locker.lockWrite(it) + val previous = testKeys.take(i) + val next = testKeys.drop(i + 1) + + previous.forEach { + assertTrue { locker.writeMutexOrNull(it) ?.isLocked == true } + } + next.forEach { + assertTrue { locker.writeMutexOrNull(it) ?.isLocked != true } + } + } + + for (i in testKeys.indices) { + val it = testKeys[i] + assertFails { realWithTimeout(13.milliseconds) { locker.lockWrite() } } + + val readPermitsBeforeLock = locker.readSemaphore().freePermits + assertFails { realWithTimeout(13.milliseconds) { locker.acquireRead() } } + assertEquals(readPermitsBeforeLock, locker.readSemaphore().freePermits) + + locker.unlockWrite(it) + } + + assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } realWithTimeout(1.seconds) { locker.lockWrite() } assertFails { realWithTimeout(13.milliseconds) { locker.acquireRead() } diff --git a/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt b/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt index 75fd251d17b..210ce222b15 100644 --- a/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt +++ b/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt @@ -6,7 +6,10 @@ import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFalse import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds class SmartRWLockerTests { @Test @@ -148,4 +151,17 @@ class SmartRWLockerTests { assertEquals(false, locker.writeMutex.isLocked) } } + + @Test + fun exceptionOnLockingWillNotLockLocker() = runTest { + val locker = SmartRWLocker() + + locker.acquireRead() + assertFails { + realWithTimeout(1.seconds) { + locker.lockWrite() + } + } + assertFalse { locker.writeMutex.isLocked } + } } diff --git a/transactions/src/commonMain/kotlin/TransactionsDSL.kt b/transactions/src/commonMain/kotlin/TransactionsDSL.kt index 566a4aca628..a743b03c4a9 100644 --- a/transactions/src/commonMain/kotlin/TransactionsDSL.kt +++ b/transactions/src/commonMain/kotlin/TransactionsDSL.kt @@ -19,7 +19,7 @@ class RollbackContext internal constructor ( * * @param rollback Will be called if */ -suspend fun TransactionsDSL.rollbackableOperation( +suspend fun TransactionsDSL.rollableBackOperation( rollback: suspend RollbackContext.() -> Unit, action: suspend () -> T ): T { @@ -34,7 +34,7 @@ suspend fun TransactionsDSL.rollbackableOperation( } /** - * Starts transaction with opportunity to add actions [rollbackableOperation]. How to use: + * Starts transaction with opportunity to add actions [rollableBackOperation]. How to use: * * ```kotlin * doSuspendTransaction { diff --git a/transactions/src/commonTest/kotlin/TransactionsDSLTests.kt b/transactions/src/commonTest/kotlin/TransactionsDSLTests.kt index 5c1ab7c9973..cea9a944df3 100644 --- a/transactions/src/commonTest/kotlin/TransactionsDSLTests.kt +++ b/transactions/src/commonTest/kotlin/TransactionsDSLTests.kt @@ -1,5 +1,5 @@ import dev.inmo.micro_utils.transactions.doSuspendTransaction -import dev.inmo.micro_utils.transactions.rollbackableOperation +import dev.inmo.micro_utils.transactions.rollableBackOperation import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals @@ -19,7 +19,7 @@ class TransactionsDSLTests { val actionResult = doSuspendTransaction { dataCollections.forEachIndexed { i, _ -> - val resultData = rollbackableOperation({ + val resultData = rollableBackOperation({ dataCollections[i] = actionResult.copy(second = true) }) { val result = dataCollections[i] @@ -56,7 +56,7 @@ class TransactionsDSLTests { val actionResult = doSuspendTransaction { dataCollections.forEachIndexed { i, _ -> - val resultData = rollbackableOperation({ + val resultData = rollableBackOperation({ assertTrue(error === this.error) dataCollections[i] = actionResult.copy(second = true) }) {