diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt new file mode 100644 index 00000000000..eb006810bf4 --- /dev/null +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartKeyRWLocker.kt @@ -0,0 +1,79 @@ +package dev.inmo.micro_utils.coroutines + +import kotlinx.coroutines.sync.withLock +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +class SmartKeyRWLocker( + private val perKeyReadPermits: Int = Int.MAX_VALUE +) { + private val internalRWLocker = SmartRWLocker() + private val lockers = mutableMapOf() + + private fun allocateLockerWithoutLock(key: T) = lockers.getOrPut(key) { + SmartRWLocker(perKeyReadPermits) + } + + suspend fun writeMutex(key: T): SmartMutex.Immutable = internalRWLocker.withReadAcquire { + allocateLockerWithoutLock(key).writeMutex + } + suspend fun readSemaphore(key: T): SmartSemaphore.Immutable = internalRWLocker.withReadAcquire { + allocateLockerWithoutLock(key).readSemaphore + } + fun writeMutexOrNull(key: T): SmartMutex.Immutable? = lockers[key] ?.writeMutex + fun readSemaphoreOrNull(key: T): SmartSemaphore.Immutable? = lockers[key] ?.readSemaphore + + suspend fun acquireRead(key: T) { + internalRWLocker.withReadAcquire { + val locker = allocateLockerWithoutLock(key) + locker.acquireRead() + } + } + suspend fun releaseRead(key: T): Boolean { + return internalRWLocker.withReadAcquire { + lockers[key] + } ?.releaseRead() == true + } + + suspend fun lockWrite(key: T) { + internalRWLocker.withWriteLock { + val locker = allocateLockerWithoutLock(key) + locker.lockWrite() + } + } + suspend fun unlockWrite(key: T): Boolean { + return internalRWLocker.withWriteLock { + lockers[key] + } ?.unlockWrite() == true + } + fun isWriteLocked(key: T): Boolean = lockers[key] ?.writeMutex ?.isLocked == true +} + +@OptIn(ExperimentalContracts::class) +suspend inline fun SmartKeyRWLocker.withReadAcquire(key: T, action: () -> R): R { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + acquireRead(key) + try { + return action() + } finally { + releaseRead(key) + } +} + +@OptIn(ExperimentalContracts::class) +suspend inline fun SmartKeyRWLocker.withWriteLock(key: T, action: () -> R): R { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + lockWrite(key) + try { + return action() + } finally { + unlockWrite(key) + } +} \ No newline at end of file diff --git a/coroutines/src/commonTest/kotlin/RealTimeOut.kt b/coroutines/src/commonTest/kotlin/RealTimeOut.kt new file mode 100644 index 00000000000..ee598385977 --- /dev/null +++ b/coroutines/src/commonTest/kotlin/RealTimeOut.kt @@ -0,0 +1,12 @@ +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlin.time.Duration + +suspend fun realWithTimeout(time: Duration, block: suspend () -> T): T { + return withContext(Dispatchers.Default.limitedParallelism(1)) { + withTimeout(time) { + block() + } + } +} diff --git a/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt b/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt new file mode 100644 index 00000000000..d4c8981a87b --- /dev/null +++ b/coroutines/src/commonTest/kotlin/SmartKeyRWLockerTests.kt @@ -0,0 +1,80 @@ +import dev.inmo.micro_utils.coroutines.* +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class SmartKeyRWLockerTests { + @Test + fun lockKeyFailedOnGlobalLockTest() = runTest { + val locker = SmartKeyRWLocker() + val testKey = "test" + locker.lockWrite() + + assertTrue { locker.isWriteLocked() } + + assertFails { + realWithTimeout(1.seconds) { + locker.lockWrite(testKey) + } + } + assertFalse { locker.isWriteLocked(testKey) } + + locker.unlockWrite() + assertFalse { locker.isWriteLocked() } + + realWithTimeout(1.seconds) { + locker.lockWrite(testKey) + } + assertTrue { locker.isWriteLocked(testKey) } + assertTrue { locker.unlockWrite(testKey) } + assertFalse { locker.isWriteLocked(testKey) } + } + @Test + fun readsBlockingGlobalWrite() = runTest { + val locker = SmartKeyRWLocker() + + val testKeys = (0 until 100).map { "test$it" } + + for (i in testKeys.indices) { + val it = testKeys[i] + locker.acquireRead(it) + val previous = testKeys.take(i) + val next = testKeys.drop(i + 1) + + previous.forEach { + assertTrue { locker.readSemaphoreOrNull(it) ?.freePermits == Int.MAX_VALUE - 1 } + } + next.forEach { + assertTrue { locker.readSemaphoreOrNull(it) ?.freePermits == null } + } + } + + for (i in testKeys.indices) { + val it = testKeys[i] + assertFails { + realWithTimeout(13.milliseconds) { locker.lockWrite() } + } + realWithTimeout(1.seconds) { locker.acquireRead() } + locker.releaseRead() + assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } + + locker.releaseRead(it) + } + + realWithTimeout(1.seconds) { locker.lockWrite() } + assertFails { + realWithTimeout(13.milliseconds) { locker.acquireRead() } + } + assertTrue { locker.unlockWrite() } + assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } + } +}