Compare commits

...

3 Commits

4 changed files with 30 additions and 28 deletions

View File

@ -1,14 +1,18 @@
package dev.inmo.micro_utils.ktor.client package dev.inmo.micro_utils.ktor.client
import dev.inmo.micro_utils.coroutines.runCatchingSafely
import dev.inmo.micro_utils.coroutines.safely import dev.inmo.micro_utils.coroutines.safely
import dev.inmo.micro_utils.ktor.common.* import dev.inmo.micro_utils.ktor.common.*
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.plugins.pluginOrNull
import io.ktor.client.plugins.websocket.WebSockets
import io.ktor.client.plugins.websocket.ws import io.ktor.client.plugins.websocket.ws
import io.ktor.client.request.HttpRequestBuilder import io.ktor.client.request.HttpRequestBuilder
import io.ktor.websocket.Frame import io.ktor.websocket.Frame
import io.ktor.websocket.readBytes import io.ktor.websocket.readBytes
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.isActive
import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.DeserializationStrategy
/** /**
@ -17,43 +21,41 @@ import kotlinx.serialization.DeserializationStrategy
*/ */
inline fun <T> HttpClient.createStandardWebsocketFlow( inline fun <T> HttpClient.createStandardWebsocketFlow(
url: String, url: String,
crossinline checkReconnection: (Throwable?) -> Boolean = { true }, crossinline checkReconnection: suspend (Throwable?) -> Boolean = { true },
noinline requestBuilder: HttpRequestBuilder.() -> Unit = {}, noinline requestBuilder: HttpRequestBuilder.() -> Unit = {},
crossinline conversation: suspend (StandardKtorSerialInputData) -> T crossinline conversation: suspend (StandardKtorSerialInputData) -> T
): Flow<T> { ): Flow<T> {
pluginOrNull(WebSockets) ?: error("Plugin $WebSockets must be installed for using createStandardWebsocketFlow")
val correctedUrl = url.asCorrectWebSocketUrl val correctedUrl = url.asCorrectWebSocketUrl
return channelFlow { return channelFlow {
val producerScope = this@channelFlow
do { do {
val reconnect = try { val reconnect = runCatchingSafely {
safely { ws(correctedUrl, requestBuilder) {
ws(correctedUrl, requestBuilder) { for (received in incoming) {
for (received in incoming) { when (received) {
when (received) { is Frame.Binary -> send(conversation(received.data))
is Frame.Binary -> producerScope.send(conversation(received.readBytes())) else -> {
else -> { close()
producerScope.close() return@ws
return@ws
}
} }
} }
} }
} }
checkReconnection(null) checkReconnection(null)
} catch (e: Throwable) { }.getOrElse { e ->
checkReconnection(e).also { checkReconnection(e).also {
if (!it) { if (!it) {
producerScope.close(e) close(e)
} }
} }
} }
} while (reconnect) } while (reconnect && isActive)
if (!producerScope.isClosedForSend) {
safely( if (isActive) {
{ it.printStackTrace() } safely {
) { close()
producerScope.close()
} }
} }
} }
@ -65,8 +67,8 @@ inline fun <T> HttpClient.createStandardWebsocketFlow(
*/ */
inline fun <T> HttpClient.createStandardWebsocketFlow( inline fun <T> HttpClient.createStandardWebsocketFlow(
url: String, url: String,
crossinline checkReconnection: (Throwable?) -> Boolean = { true },
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
crossinline checkReconnection: suspend (Throwable?) -> Boolean = { true },
serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat, serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat,
noinline requestBuilder: HttpRequestBuilder.() -> Unit = {}, noinline requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = createStandardWebsocketFlow( ) = createStandardWebsocketFlow(

View File

@ -87,16 +87,16 @@ class UnifiedRequester(
fun <T> createStandardWebsocketFlow( fun <T> createStandardWebsocketFlow(
url: String, url: String,
checkReconnection: (Throwable?) -> Boolean, checkReconnection: suspend (Throwable?) -> Boolean,
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
requestBuilder: HttpRequestBuilder.() -> Unit = {}, requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = client.createStandardWebsocketFlow(url, checkReconnection, deserializer, serialFormat, requestBuilder) ) = client.createStandardWebsocketFlow(url, deserializer, checkReconnection, serialFormat, requestBuilder)
fun <T> createStandardWebsocketFlow( fun <T> createStandardWebsocketFlow(
url: String, url: String,
deserializer: DeserializationStrategy<T>, deserializer: DeserializationStrategy<T>,
requestBuilder: HttpRequestBuilder.() -> Unit = {}, requestBuilder: HttpRequestBuilder.() -> Unit = {},
) = createStandardWebsocketFlow(url, { true }, deserializer, requestBuilder) ) = createStandardWebsocketFlow(url, { true }, deserializer, requestBuilder)
} }
val defaultRequester = UnifiedRequester() val defaultRequester = UnifiedRequester()

View File

@ -15,13 +15,13 @@ import kotlinx.serialization.SerializationStrategy
fun <T> Route.includeWebsocketHandling( fun <T> Route.includeWebsocketHandling(
suburl: String, suburl: String,
flow: Flow<T>, flow: Flow<T>,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
converter: suspend WebSocketServerSession.(T) -> StandardKtorSerialInputData? converter: suspend WebSocketServerSession.(T) -> StandardKtorSerialInputData?
) { ) {
application.apply { application.apply {
pluginOrNull(WebSockets) ?: install(WebSockets) pluginOrNull(WebSockets) ?: install(WebSockets)
} }
webSocket(suburl, protocol.name) { webSocket(suburl, protocol ?.name) {
safely { safely {
flow.collect { flow.collect {
converter(it) ?.let { data -> converter(it) ?.let { data ->
@ -37,7 +37,7 @@ fun <T> Route.includeWebsocketHandling(
flow: Flow<T>, flow: Flow<T>,
serializer: SerializationStrategy<T>, serializer: SerializationStrategy<T>,
serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat, serialFormat: StandardKtorSerialFormat = standardKtorSerialFormat,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null
) = includeWebsocketHandling( ) = includeWebsocketHandling(
suburl, suburl,

View File

@ -27,7 +27,7 @@ class UnifiedRouter(
suburl: String, suburl: String,
flow: Flow<T>, flow: Flow<T>,
serializer: SerializationStrategy<T>, serializer: SerializationStrategy<T>,
protocol: URLProtocol = URLProtocol.WS, protocol: URLProtocol? = null,
filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null filter: (suspend WebSocketServerSession.(T) -> Boolean)? = null
) = includeWebsocketHandling(suburl, flow, serializer, serialFormat, protocol, filter) ) = includeWebsocketHandling(suburl, flow, serializer, serialFormat, protocol, filter)