diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt index 0e79aca6f21..2ca18d8d4ed 100644 --- a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartRWLocker.kt @@ -12,9 +12,9 @@ import kotlin.contracts.contract * * [lockWrite] will lock [writeMutex] and then await while all [readSemaphore] will be freed * * [unlockWrite] will just unlock [writeMutex] */ -class SmartRWLocker(private val readPermits: Int = Int.MAX_VALUE) { +class SmartRWLocker(private val readPermits: Int = Int.MAX_VALUE, writeIsLocked: Boolean = false) { private val _readSemaphore = SmartSemaphore.Mutable(permits = readPermits, acquiredPermits = 0) - private val _writeMutex = SmartMutex.Mutable(locked = false) + private val _writeMutex = SmartMutex.Mutable(locked = writeIsLocked) val readSemaphore: SmartSemaphore.Immutable = _readSemaphore.immutable() val writeMutex: SmartMutex.Immutable = _writeMutex.immutable() diff --git a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt index bfb29feb599..fd072f1e415 100644 --- a/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt +++ b/coroutines/src/commonMain/kotlin/dev/inmo/micro_utils/coroutines/SmartSemaphore.kt @@ -80,9 +80,9 @@ sealed interface SmartSemaphore { */ suspend fun tryAcquire(permits: Int = 1): Boolean { val checkedPermits = checkedPermits(permits) - return if (_permitsStateFlow.value >= checkedPermits) { + return if (_permitsStateFlow.value < checkedPermits) { internalChangesMutex.withLock { - if (_permitsStateFlow.value >= checkedPermits) { + if (_permitsStateFlow.value < checkedPermits) { _permitsStateFlow.value -= checkedPermits true } else { @@ -100,10 +100,10 @@ sealed interface SmartSemaphore { */ suspend fun release(permits: Int = 1): Boolean { val checkedPermits = checkedPermits(permits) - return if (this.permits - _permitsStateFlow.value > checkedPermits) { + return if (_permitsStateFlow.value < this.permits) { internalChangesMutex.withLock { - if (this.permits - _permitsStateFlow.value > checkedPermits) { - _permitsStateFlow.value += checkedPermits + if (_permitsStateFlow.value < this.permits) { + _permitsStateFlow.value = minOf(_permitsStateFlow.value + checkedPermits, this.permits) true } else { false diff --git a/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt b/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt new file mode 100644 index 00000000000..cef0ff93526 --- /dev/null +++ b/coroutines/src/commonTest/kotlin/SmartRWLockerTests.kt @@ -0,0 +1,60 @@ +import dev.inmo.micro_utils.coroutines.SmartRWLocker +import dev.inmo.micro_utils.coroutines.withReadAcquire +import dev.inmo.micro_utils.coroutines.withWriteLock +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class SmartRWLockerTests { + @Test + fun compositeTest() { + val locker = SmartRWLocker() + + val readAndWriteWorkers = 10 + runTest { + var started = 0 + var done = 0 + val doneMutex = Mutex() + val readWorkers = (0 until readAndWriteWorkers).map { + launch(start = CoroutineStart.LAZY) { + locker.withReadAcquire { + doneMutex.withLock { + started++ + } + delay(100L) + doneMutex.withLock { + done++ + } + } + } + } + + var doneWrites = 0 + + val writeWorkers = (0 until readAndWriteWorkers).map { + launch(start = CoroutineStart.LAZY) { + locker.withWriteLock { + assertTrue(done == readAndWriteWorkers || started == 0) + delay(10L) + doneWrites++ + } + } + } + readWorkers.forEach { it.start() } + writeWorkers.forEach { it.start() } + + readWorkers.joinAll() + writeWorkers.joinAll() + + assertEquals(expected = readAndWriteWorkers, actual = done) + assertEquals(expected = readAndWriteWorkers, actual = doneWrites) + } + } +} diff --git a/mppAndroidProject.gradle b/mppAndroidProject.gradle index 1817e67049d..7bd45386817 100644 --- a/mppAndroidProject.gradle +++ b/mppAndroidProject.gradle @@ -18,6 +18,7 @@ kotlin { dependencies { implementation kotlin('test-common') implementation kotlin('test-annotations-common') + implementation libs.kt.coroutines.test } } } diff --git a/mppJavaProject.gradle b/mppJavaProject.gradle index 5f6d62de0cb..4d2f2b6d5e4 100644 --- a/mppJavaProject.gradle +++ b/mppJavaProject.gradle @@ -22,6 +22,7 @@ kotlin { dependencies { implementation kotlin('test-common') implementation kotlin('test-annotations-common') + implementation libs.kt.coroutines.test } } diff --git a/mppJvmJsLinuxMingwProject.gradle b/mppJvmJsLinuxMingwProject.gradle index 1a5b08c7de5..a38f9435da1 100644 --- a/mppJvmJsLinuxMingwProject.gradle +++ b/mppJvmJsLinuxMingwProject.gradle @@ -28,6 +28,7 @@ kotlin { dependencies { implementation kotlin('test-common') implementation kotlin('test-annotations-common') + implementation libs.kt.coroutines.test } } diff --git a/mppProjectWithSerialization.gradle b/mppProjectWithSerialization.gradle index 8d0239816e5..2ed6bff423d 100644 --- a/mppProjectWithSerialization.gradle +++ b/mppProjectWithSerialization.gradle @@ -32,6 +32,7 @@ kotlin { dependencies { implementation kotlin('test-common') implementation kotlin('test-annotations-common') + implementation libs.kt.coroutines.test } } jvmTest { diff --git a/mppProjectWithSerializationAndCompose.gradle b/mppProjectWithSerializationAndCompose.gradle index e7b75fbd456..d05dcb9f3b6 100644 --- a/mppProjectWithSerializationAndCompose.gradle +++ b/mppProjectWithSerializationAndCompose.gradle @@ -31,6 +31,7 @@ kotlin { dependencies { implementation kotlin('test-common') implementation kotlin('test-annotations-common') + implementation libs.kt.coroutines.test } } jvmMain {