potentially first version of SmartKeyRWLocker

This commit is contained in:
2025-03-23 11:16:45 +06:00
parent 4c9e435df8
commit 761070b9b7
7 changed files with 182 additions and 28 deletions

View File

@@ -1,55 +1,143 @@
package dev.inmo.micro_utils.coroutines package dev.inmo.micro_utils.coroutines
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import kotlin.contracts.ExperimentalContracts import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
class SmartKeyRWLocker<T>( class SmartKeyRWLocker<T>(
globalLockerReadPermits: Int = Int.MAX_VALUE,
globalLockerWriteIsLocked: Boolean = false,
private val perKeyReadPermits: Int = Int.MAX_VALUE private val perKeyReadPermits: Int = Int.MAX_VALUE
) { ) {
private val internalRWLocker = SmartRWLocker() private val globalRWLocker: SmartRWLocker = SmartRWLocker(
readPermits = globalLockerReadPermits,
writeIsLocked = globalLockerWriteIsLocked
)
private val lockers = mutableMapOf<T, SmartRWLocker>() private val lockers = mutableMapOf<T, SmartRWLocker>()
private val lockersMutex = Mutex()
private val lockersWritingLocker = SmartSemaphore.Mutable(Int.MAX_VALUE)
private val globalWritingLocker = SmartSemaphore.Mutable(Int.MAX_VALUE)
private fun allocateLockerWithoutLock(key: T) = lockers.getOrPut(key) { private fun allocateLockerWithoutLock(key: T) = lockers.getOrPut(key) {
SmartRWLocker(perKeyReadPermits) SmartRWLocker(perKeyReadPermits)
} }
private suspend fun allocateLocker(key: T) = lockersMutex.withLock {
lockers.getOrPut(key) {
SmartRWLocker(perKeyReadPermits)
}
}
suspend fun writeMutex(key: T): SmartMutex.Immutable = internalRWLocker.withReadAcquire { suspend fun writeMutex(key: T): SmartMutex.Immutable = globalRWLocker.withReadAcquire {
allocateLockerWithoutLock(key).writeMutex allocateLockerWithoutLock(key).writeMutex
} }
suspend fun readSemaphore(key: T): SmartSemaphore.Immutable = internalRWLocker.withReadAcquire { suspend fun readSemaphore(key: T): SmartSemaphore.Immutable = globalRWLocker.withReadAcquire {
allocateLockerWithoutLock(key).readSemaphore allocateLockerWithoutLock(key).readSemaphore
} }
fun writeMutexOrNull(key: T): SmartMutex.Immutable? = lockers[key] ?.writeMutex fun writeMutexOrNull(key: T): SmartMutex.Immutable? = lockers[key] ?.writeMutex
fun readSemaphoreOrNull(key: T): SmartSemaphore.Immutable? = lockers[key] ?.readSemaphore fun readSemaphoreOrNull(key: T): SmartSemaphore.Immutable? = lockers[key] ?.readSemaphore
fun writeMutex(): SmartMutex.Immutable = globalRWLocker.writeMutex
fun readSemaphore(): SmartSemaphore.Immutable = globalRWLocker.readSemaphore
suspend fun acquireRead() {
globalWritingLocker.acquire()
try {
lockersWritingLocker.waitReleaseAll()
globalRWLocker.acquireRead()
} catch (e: CancellationException) {
globalWritingLocker.release()
throw e
}
}
suspend fun releaseRead(): Boolean {
globalWritingLocker.release()
return globalRWLocker.releaseRead()
}
suspend fun lockWrite() {
globalRWLocker.lockWrite()
}
suspend fun unlockWrite(): Boolean {
return globalRWLocker.unlockWrite()
}
fun isWriteLocked(): Boolean = globalRWLocker.writeMutex.isLocked == true
suspend fun acquireRead(key: T) { suspend fun acquireRead(key: T) {
internalRWLocker.withReadAcquire { globalRWLocker.acquireRead()
val locker = allocateLockerWithoutLock(key) val locker = allocateLocker(key)
try {
locker.acquireRead() locker.acquireRead()
} catch (e: CancellationException) {
globalRWLocker.releaseRead()
throw e
} }
} }
suspend fun releaseRead(key: T): Boolean { suspend fun releaseRead(key: T): Boolean {
return internalRWLocker.withReadAcquire { val locker = allocateLocker(key)
lockers[key] return locker.releaseRead() && globalRWLocker.releaseRead()
} ?.releaseRead() == true
} }
suspend fun lockWrite(key: T) { suspend fun lockWrite(key: T) {
internalRWLocker.withWriteLock { globalWritingLocker.withAcquire(globalWritingLocker.maxPermits) {
val locker = allocateLockerWithoutLock(key) lockersWritingLocker.acquire()
locker.lockWrite() }
try {
globalRWLocker.acquireRead()
try {
val locker = allocateLocker(key)
locker.lockWrite()
} catch (e: CancellationException) {
globalRWLocker.releaseRead()
throw e
}
} catch (e: CancellationException) {
lockersWritingLocker.release()
throw e
} }
} }
suspend fun unlockWrite(key: T): Boolean { suspend fun unlockWrite(key: T): Boolean {
return internalRWLocker.withWriteLock { val locker = allocateLocker(key)
lockers[key] return (locker.unlockWrite() && globalRWLocker.releaseRead()).also {
} ?.unlockWrite() == true if (it) {
lockersWritingLocker.release()
}
}
} }
fun isWriteLocked(key: T): Boolean = lockers[key] ?.writeMutex ?.isLocked == true fun isWriteLocked(key: T): Boolean = lockers[key] ?.writeMutex ?.isLocked == true
} }
@OptIn(ExperimentalContracts::class)
suspend inline fun <T, R> SmartKeyRWLocker<T>.withReadAcquire(action: () -> R): R {
contract {
callsInPlace(action, InvocationKind.EXACTLY_ONCE)
}
acquireRead()
try {
return action()
} finally {
releaseRead()
}
}
@OptIn(ExperimentalContracts::class)
suspend inline fun <T, R> SmartKeyRWLocker<T>.withWriteLock(action: () -> R): R {
contract {
callsInPlace(action, InvocationKind.EXACTLY_ONCE)
}
lockWrite()
try {
return action()
} finally {
unlockWrite()
}
}
@OptIn(ExperimentalContracts::class) @OptIn(ExperimentalContracts::class)
suspend inline fun <T, R> SmartKeyRWLocker<T>.withReadAcquire(key: T, action: () -> R): R { suspend inline fun <T, R> SmartKeyRWLocker<T>.withReadAcquire(key: T, action: () -> R): R {
contract { contract {

View File

@@ -1,5 +1,6 @@
package dev.inmo.micro_utils.coroutines package dev.inmo.micro_utils.coroutines
import kotlinx.coroutines.CancellationException
import kotlin.contracts.ExperimentalContracts import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@@ -39,7 +40,12 @@ class SmartRWLocker(private val readPermits: Int = Int.MAX_VALUE, writeIsLocked:
*/ */
suspend fun lockWrite() { suspend fun lockWrite() {
_writeMutex.lock() _writeMutex.lock()
_readSemaphore.acquire(readPermits) try {
_readSemaphore.acquire(readPermits)
} catch (e: CancellationException) {
_writeMutex.unlock()
throw e
}
} }
/** /**

View File

@@ -24,6 +24,7 @@ import kotlin.contracts.contract
* [Mutable] creator * [Mutable] creator
*/ */
sealed interface SmartSemaphore { sealed interface SmartSemaphore {
val maxPermits: Int
val permitsStateFlow: StateFlow<Int> val permitsStateFlow: StateFlow<Int>
/** /**
@@ -36,7 +37,7 @@ sealed interface SmartSemaphore {
/** /**
* Immutable variant of [SmartSemaphore]. In fact will depend on the owner of [permitsStateFlow] * Immutable variant of [SmartSemaphore]. In fact will depend on the owner of [permitsStateFlow]
*/ */
class Immutable(override val permitsStateFlow: StateFlow<Int>) : SmartSemaphore class Immutable(override val permitsStateFlow: StateFlow<Int>, override val maxPermits: Int) : SmartSemaphore
/** /**
* Mutable variant of [SmartSemaphore]. With that variant you may [lock] and [unlock]. Besides, you may create * Mutable variant of [SmartSemaphore]. With that variant you may [lock] and [unlock]. Besides, you may create
@@ -44,15 +45,16 @@ sealed interface SmartSemaphore {
* *
* @param locked Preset state of [freePermits] and its internal [_freePermitsStateFlow] * @param locked Preset state of [freePermits] and its internal [_freePermitsStateFlow]
*/ */
class Mutable(private val permits: Int, acquiredPermits: Int = 0) : SmartSemaphore { class Mutable(permits: Int, acquiredPermits: Int = 0) : SmartSemaphore {
override val maxPermits: Int = permits
private val _freePermitsStateFlow = SpecialMutableStateFlow<Int>(permits - acquiredPermits) private val _freePermitsStateFlow = SpecialMutableStateFlow<Int>(permits - acquiredPermits)
override val permitsStateFlow: StateFlow<Int> = _freePermitsStateFlow.asStateFlow() override val permitsStateFlow: StateFlow<Int> = _freePermitsStateFlow.asStateFlow()
private val internalChangesMutex = Mutex(false) private val internalChangesMutex = Mutex(false)
fun immutable() = Immutable(permitsStateFlow) fun immutable() = Immutable(permitsStateFlow, maxPermits)
private fun checkedPermits(permits: Int) = permits.coerceIn(1 .. this.permits) private fun checkedPermits(permits: Int) = permits.coerceIn(1 .. this.maxPermits)
/** /**
* Holds call until this [SmartSemaphore] will be re-locked. That means that current method will * Holds call until this [SmartSemaphore] will be re-locked. That means that current method will
@@ -126,10 +128,10 @@ sealed interface SmartSemaphore {
*/ */
suspend fun release(permits: Int = 1): Boolean { suspend fun release(permits: Int = 1): Boolean {
val checkedPermits = checkedPermits(permits) val checkedPermits = checkedPermits(permits)
return if (_freePermitsStateFlow.value < this.permits) { return if (_freePermitsStateFlow.value < this.maxPermits) {
internalChangesMutex.withLock { internalChangesMutex.withLock {
if (_freePermitsStateFlow.value < this.permits) { if (_freePermitsStateFlow.value < this.maxPermits) {
_freePermitsStateFlow.value = minOf(_freePermitsStateFlow.value + checkedPermits, this.permits) _freePermitsStateFlow.value = minOf(_freePermitsStateFlow.value + checkedPermits, this.maxPermits)
true true
} else { } else {
false false
@@ -166,3 +168,4 @@ suspend inline fun <T> SmartSemaphore.Mutable.withAcquire(permits: Int = 1, acti
* the fact that some other parties may lock it again * the fact that some other parties may lock it again
*/ */
suspend fun SmartSemaphore.waitRelease(permits: Int = 1) = permitsStateFlow.first { it >= permits } suspend fun SmartSemaphore.waitRelease(permits: Int = 1) = permitsStateFlow.first { it >= permits }
suspend fun SmartSemaphore.waitReleaseAll() = permitsStateFlow.first { it == maxPermits }

View File

@@ -63,13 +63,54 @@ class SmartKeyRWLockerTests {
assertFails { assertFails {
realWithTimeout(13.milliseconds) { locker.lockWrite() } realWithTimeout(13.milliseconds) { locker.lockWrite() }
} }
val readPermitsBeforeLock = locker.readSemaphore().freePermits
realWithTimeout(1.seconds) { locker.acquireRead() } realWithTimeout(1.seconds) { locker.acquireRead() }
locker.releaseRead() locker.releaseRead()
assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE } assertEquals(readPermitsBeforeLock, locker.readSemaphore().freePermits)
locker.releaseRead(it) locker.releaseRead(it)
} }
assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE }
realWithTimeout(1.seconds) { locker.lockWrite() }
assertFails {
realWithTimeout(13.milliseconds) { locker.acquireRead() }
}
assertTrue { locker.unlockWrite() }
assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE }
}
@Test
fun writesBlockingGlobalWrite() = runTest {
val locker = SmartKeyRWLocker<String>()
val testKeys = (0 until 100).map { "test$it" }
for (i in testKeys.indices) {
val it = testKeys[i]
locker.lockWrite(it)
val previous = testKeys.take(i)
val next = testKeys.drop(i + 1)
previous.forEach {
assertTrue { locker.writeMutexOrNull(it) ?.isLocked == true }
}
next.forEach {
assertTrue { locker.writeMutexOrNull(it) ?.isLocked != true }
}
}
for (i in testKeys.indices) {
val it = testKeys[i]
assertFails { realWithTimeout(13.milliseconds) { locker.lockWrite() } }
val readPermitsBeforeLock = locker.readSemaphore().freePermits
assertFails { realWithTimeout(13.milliseconds) { locker.acquireRead() } }
assertEquals(readPermitsBeforeLock, locker.readSemaphore().freePermits)
locker.unlockWrite(it)
}
assertTrue { locker.readSemaphore().freePermits == Int.MAX_VALUE }
realWithTimeout(1.seconds) { locker.lockWrite() } realWithTimeout(1.seconds) { locker.lockWrite() }
assertFails { assertFails {
realWithTimeout(13.milliseconds) { locker.acquireRead() } realWithTimeout(13.milliseconds) { locker.acquireRead() }

View File

@@ -6,7 +6,10 @@ import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertFalse
import kotlin.test.assertTrue import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.seconds
class SmartRWLockerTests { class SmartRWLockerTests {
@Test @Test
@@ -148,4 +151,17 @@ class SmartRWLockerTests {
assertEquals(false, locker.writeMutex.isLocked) assertEquals(false, locker.writeMutex.isLocked)
} }
} }
@Test
fun exceptionOnLockingWillNotLockLocker() = runTest {
val locker = SmartRWLocker()
locker.acquireRead()
assertFails {
realWithTimeout(1.seconds) {
locker.lockWrite()
}
}
assertFalse { locker.writeMutex.isLocked }
}
} }

View File

@@ -19,7 +19,7 @@ class RollbackContext<T> internal constructor (
* *
* @param rollback Will be called if * @param rollback Will be called if
*/ */
suspend fun <T> TransactionsDSL.rollbackableOperation( suspend fun <T> TransactionsDSL.rollableBackOperation(
rollback: suspend RollbackContext<T>.() -> Unit, rollback: suspend RollbackContext<T>.() -> Unit,
action: suspend () -> T action: suspend () -> T
): T { ): T {
@@ -34,7 +34,7 @@ suspend fun <T> TransactionsDSL.rollbackableOperation(
} }
/** /**
* Starts transaction with opportunity to add actions [rollbackableOperation]. How to use: * Starts transaction with opportunity to add actions [rollableBackOperation]. How to use:
* *
* ```kotlin * ```kotlin
* doSuspendTransaction { * doSuspendTransaction {

View File

@@ -1,5 +1,5 @@
import dev.inmo.micro_utils.transactions.doSuspendTransaction import dev.inmo.micro_utils.transactions.doSuspendTransaction
import dev.inmo.micro_utils.transactions.rollbackableOperation import dev.inmo.micro_utils.transactions.rollableBackOperation
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@@ -19,7 +19,7 @@ class TransactionsDSLTests {
val actionResult = doSuspendTransaction { val actionResult = doSuspendTransaction {
dataCollections.forEachIndexed { i, _ -> dataCollections.forEachIndexed { i, _ ->
val resultData = rollbackableOperation({ val resultData = rollableBackOperation({
dataCollections[i] = actionResult.copy(second = true) dataCollections[i] = actionResult.copy(second = true)
}) { }) {
val result = dataCollections[i] val result = dataCollections[i]
@@ -56,7 +56,7 @@ class TransactionsDSLTests {
val actionResult = doSuspendTransaction { val actionResult = doSuspendTransaction {
dataCollections.forEachIndexed { i, _ -> dataCollections.forEachIndexed { i, _ ->
val resultData = rollbackableOperation({ val resultData = rollableBackOperation({
assertTrue(error === this.error) assertTrue(error === this.error)
dataCollections[i] = actionResult.copy(second = true) dataCollections[i] = actionResult.copy(second = true)
}) { }) {