diff --git a/ktor/client/src/commonMain/kotlin/com/insanusmokrassar/postssystem/ktor/client/FlowsWebsockets.kt b/ktor/client/src/commonMain/kotlin/com/insanusmokrassar/postssystem/ktor/client/FlowsWebsockets.kt index 036cbc6b..ee6dadda 100644 --- a/ktor/client/src/commonMain/kotlin/com/insanusmokrassar/postssystem/ktor/client/FlowsWebsockets.kt +++ b/ktor/client/src/commonMain/kotlin/com/insanusmokrassar/postssystem/ktor/client/FlowsWebsockets.kt @@ -4,41 +4,60 @@ import com.insanusmokrassar.postssystem.ktor.* import com.insanusmokrassar.postssystem.utils.common.safely import io.ktor.client.HttpClient import io.ktor.client.features.websocket.ws -import io.ktor.client.request.url -import io.ktor.http.HttpMethod import io.ktor.http.cio.websocket.* import kotlinx.coroutines.flow.* -import kotlinx.coroutines.isActive +/** + * @param checkReconnection This lambda will be called when it is required to reconnect to websocket to establish + * connection. Must return true in case if must be reconnected. By default always reconnecting + */ inline fun createStandardWebsocketFlow( client: HttpClient, url: String, + crossinline checkReconnection: (Throwable?) -> Boolean = { true }, crossinline conversation: suspend (ByteArray) -> T ): Flow { val correctedUrl = url.asCorrectWebSocketUrl return channelFlow { val producerScope = this - safely( - { - producerScope.close() - throw it - } - ) { - client.ws( - correctedUrl - ) { - while (true) { - when (val received = incoming.receive()) { - is Frame.Binary -> producerScope.send( - conversation(received.readBytes()) - ) - else -> { - producerScope.close() - return@ws + do { + val reconnect = try { + safely( + { + throw it + } + ) { + client.ws( + correctedUrl + ) { + while (true) { + when (val received = incoming.receive()) { + is Frame.Binary -> producerScope.send( + conversation(received.readBytes()) + ) + else -> { + producerScope.close() + return@ws + } + } } } } + checkReconnection(null) + } catch (e: Throwable) { + checkReconnection(e).also { + if (!it) { + producerScope.close(e) + } + } + } + } while (reconnect) + if (!producerScope.isClosedForSend) { + safely( + { /* do nothing */ } + ) { + producerScope.close() } } } diff --git a/ktor/tests/src/test/kotlin/com/insanusmokrassar/postssystem/ktor/tests/WebsocketsTest.kt b/ktor/tests/src/test/kotlin/com/insanusmokrassar/postssystem/ktor/tests/WebsocketsTest.kt index 04ca92d8..8cbb0986 100644 --- a/ktor/tests/src/test/kotlin/com/insanusmokrassar/postssystem/ktor/tests/WebsocketsTest.kt +++ b/ktor/tests/src/test/kotlin/com/insanusmokrassar/postssystem/ktor/tests/WebsocketsTest.kt @@ -47,7 +47,8 @@ class WebsocketsTest { } val incomingWebsocketFlow = createStandardWebsocketFlow( client, - "$serverUrl/$suburl" + "$serverUrl/$suburl", + { false } // always skip reconnection ) { standardKtorSerializer.load(Int.serializer(), it) }