fixes in websockets

This commit is contained in:
InsanusMokrassar 2022-04-28 19:43:52 +06:00
parent a13cc9e961
commit 9b9e7dd88f
4 changed files with 25 additions and 27 deletions

View File

@ -1,5 +1,6 @@
package dev.inmo.micro_utils.ktor.client package dev.inmo.micro_utils.ktor.client
import dev.inmo.micro_utils.coroutines.runCatchingSafely
import dev.inmo.micro_utils.coroutines.safely import dev.inmo.micro_utils.coroutines.safely
import dev.inmo.micro_utils.ktor.common.* import dev.inmo.micro_utils.ktor.common.*
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
@ -11,6 +12,7 @@ import io.ktor.websocket.Frame
import io.ktor.websocket.readBytes import io.ktor.websocket.readBytes
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.isActive
import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.DeserializationStrategy
/** /**
@ -19,7 +21,7 @@ import kotlinx.serialization.DeserializationStrategy
*/ */
inline fun <T> HttpClient.createStandardWebsocketFlow( inline fun <T> HttpClient.createStandardWebsocketFlow(
url: String, url: String,
crossinline checkReconnection: (Throwable?) -> Boolean = { true }, crossinline checkReconnection: suspend (Throwable?) -> Boolean = { true },
noinline requestBuilder: HttpRequestBuilder.() -> Unit = {}, noinline requestBuilder: HttpRequestBuilder.() -> Unit = {},
crossinline conversation: suspend (StandardKtorSerialInputData) -> T crossinline conversation: suspend (StandardKtorSerialInputData) -> T
): Flow<T> { ): Flow<T> {
@ -28,36 +30,32 @@ inline fun <T> HttpClient.createStandardWebsocketFlow(
val correctedUrl = url.asCorrectWebSocketUrl val correctedUrl = url.asCorrectWebSocketUrl
return channelFlow { return channelFlow {
val producerScope = this@channelFlow
do { do {
val reconnect = try { val reconnect = runCatchingSafely {
safely { ws(correctedUrl, requestBuilder) {
ws(correctedUrl, requestBuilder) { for (received in incoming) {
for (received in incoming) { when (received) {
when (received) { is Frame.Binary -> send(conversation(received.data))
is Frame.Binary -> producerScope.send(conversation(received.readBytes())) else -> {
else -> { close()
producerScope.close() return@ws
return@ws
}
} }
} }
} }
} }
checkReconnection(null) checkReconnection(null)
} catch (e: Throwable) { }.getOrElse { e ->
checkReconnection(e).also { checkReconnection(e).also {
if (!it) { if (!it) {
producerScope.close(e) close(e)
} }
} }
} }
} while (reconnect) } while (reconnect && isActive)
if (!producerScope.isClosedForSend) {
safely( if (isActive) {
{ it.printStackTrace() } safely {
) { close()
producerScope.close()
} }
} }
} }
@ -70,7 +68,7 @@ inline fun <T> HttpClient.createStandardWebsocketFlow(
inline fun <T> HttpClient.createStandardWebsocketFlow( inline fun <T> HttpClient.createStandardWebsocketFlow(
url: String, url: String,
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
crossinline checkReconnection: (Throwable?) -> Boolean = { true }, crossinline checkReconnection: suspend (Throwable?) -> Boolean = { true },
serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat, serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat,
noinline requestBuilder: HttpRequestBuilder.() -> Unit = {}, noinline requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = createStandardWebsocketFlow( ) = createStandardWebsocketFlow(

View File

@ -87,7 +87,7 @@ class UnifiedRequester(
fun <T> createStandardWebsocketFlow( fun <T> createStandardWebsocketFlow(
url: String, url: String,
checkReconnection: (Throwable?) -> Boolean, checkReconnection: suspend (Throwable?) -> Boolean,
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
requestBuilder: HttpRequestBuilder.() -> Unit = {}, requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = client.createStandardWebsocketFlow(url, deserializer, checkReconnection, serialFormat, requestBuilder) ) = client.createStandardWebsocketFlow(url, deserializer, checkReconnection, serialFormat, requestBuilder)
@ -96,7 +96,7 @@ class UnifiedRequester(
url: String, url: String,
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
requestBuilder: HttpRequestBuilder.() -> Unit = {}, requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = createStandardWebsocketFlow(url, { true }, deserializer, requestBuilder) ) = createStandardWebsocketFlow(url, { true }, deserializer, requestBuilder)
} }
val defaultRequester = UnifiedRequester() val defaultRequester = UnifiedRequester()

View File

@ -15,13 +15,13 @@ import kotlinx.serialization.SerializationStrategy
fun <T> Route.includeWebsocketHandling( fun <T> Route.includeWebsocketHandling(
suburl: String, suburl: String,
flow: Flow<T>, flow: Flow<T>,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
converter: suspend WebSocketServerSession.(T) -> StandardKtorSerialInputData? converter: suspend WebSocketServerSession.(T) -> StandardKtorSerialInputData?
) { ) {
application.apply { application.apply {
pluginOrNull(WebSockets) ?: install(WebSockets) pluginOrNull(WebSockets) ?: install(WebSockets)
} }
webSocket(suburl, protocol.name) { webSocket(suburl, protocol ?.name) {
safely { safely {
flow.collect { flow.collect {
converter(it) ?.let { data -> converter(it) ?.let { data ->
@ -37,7 +37,7 @@ fun <T> Route.includeWebsocketHandling(
flow: Flow<T>, flow: Flow<T>,
serializer: SerializationStrategy<T>, serializer: SerializationStrategy<T>,
serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat, serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null
) = includeWebsocketHandling( ) = includeWebsocketHandling(
suburl, suburl,

View File

@ -27,7 +27,7 @@ class UnifiedRouter(
suburl: String, suburl: String,
flow: Flow<T>, flow: Flow<T>,
serializer: SerializationStrategy<T>, serializer: SerializationStrategy<T>,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null
) = includeWebsocketHandling(suburl, flow, serializer, serialFormat, protocol, filter) ) = includeWebsocketHandling(suburl, flow, serializer, serialFormat, protocol, filter)