refactor: split Gateway into Handler and Session
This commit is contained in:
parent
214efd7417
commit
b04d75df99
6 changed files with 108 additions and 89 deletions
|
|
@ -0,0 +1,38 @@
|
|||
package moe.lava.neon.core.api.gateway
|
||||
|
||||
import co.touchlab.kermit.Logger
|
||||
import dev.zacsweers.metro.Inject
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import moe.lava.neon.core.di.EventHandlerGraph
|
||||
import moe.lava.neon.core.repository.AuthRepository
|
||||
|
||||
@Inject
|
||||
class GatewayHandler(
|
||||
private val auth: AuthRepository,
|
||||
private val handlers: EventHandlerGraph,
|
||||
) {
|
||||
private val logger = Logger.withTag("neon.core.api.gateway/handler")
|
||||
private var session: GatewaySession? = null
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
suspend fun connect() {
|
||||
if (session != null) {
|
||||
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
|
||||
return
|
||||
}
|
||||
val token = auth.token
|
||||
?: throw IllegalStateException("Tried to connect to gateway with no token")
|
||||
|
||||
session = GatewaySession.start(
|
||||
token = token,
|
||||
eventHandlers = handlers,
|
||||
onDestroy = { session = null }
|
||||
)
|
||||
}
|
||||
|
||||
suspend fun disconnect() {
|
||||
val session = session
|
||||
?: throw IllegalStateException("Tried disconnecting with no session")
|
||||
session.close()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
package moe.lava.neon.core.api.gateway
|
||||
|
||||
import co.touchlab.kermit.Logger
|
||||
import dev.zacsweers.metro.Inject
|
||||
import io.ktor.client.HttpClient
|
||||
import io.ktor.client.plugins.cookies.HttpCookies
|
||||
import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
|
||||
|
|
@ -15,46 +14,44 @@ import io.ktor.websocket.readText
|
|||
import io.ktor.websocket.send
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.consumeAsFlow
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onCompletion
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import moe.lava.neon.core.api.ApiConstants
|
||||
import moe.lava.neon.core.di.GatewayHandlerGraph
|
||||
import moe.lava.neon.core.repository.AuthRepository
|
||||
import moe.lava.neon.core.api.ApiConstants.json
|
||||
import moe.lava.neon.core.di.EventHandlerGraph
|
||||
import kotlin.random.Random
|
||||
import kotlin.time.Duration
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
|
||||
@Inject
|
||||
class Gateway(
|
||||
private val auth: AuthRepository,
|
||||
private val handlers: GatewayHandlerGraph,
|
||||
private val logger = Logger.withTag("neon.core.api.gateway/session")
|
||||
|
||||
class GatewaySession private constructor(
|
||||
private var ws: DefaultClientWebSocketSession,
|
||||
private val token: String,
|
||||
private val handlers: EventHandlerGraph,
|
||||
private val scope: CoroutineScope,
|
||||
private val onDestroy: () -> Unit,
|
||||
) {
|
||||
private val logger = Logger.withTag("neon.core.api/gateway")
|
||||
private val scope = CoroutineScope(Dispatchers.IO)
|
||||
private var ws: DefaultClientWebSocketSession? = null
|
||||
private var lastSeq: Int? = null
|
||||
private var missedHeartbeats = 0
|
||||
|
||||
private val json = ApiConstants.json
|
||||
|
||||
private var seq: Int? = null
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
suspend fun connect() {
|
||||
if (ws != null) {
|
||||
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
|
||||
return
|
||||
}
|
||||
if (auth.token == null) {
|
||||
throw IllegalStateException("Tried to connect to gateway with no token")
|
||||
}
|
||||
val ws = HttpClient {
|
||||
companion object {
|
||||
suspend fun start(
|
||||
token: String,
|
||||
eventHandlers: EventHandlerGraph,
|
||||
client: HttpClient = HttpClient {
|
||||
install(HttpCookies)
|
||||
install(WebSockets)
|
||||
}.webSocketSession("wss://gateway.discord.gg/") {
|
||||
},
|
||||
scope: CoroutineScope = CoroutineScope(Dispatchers.IO),
|
||||
onDestroy: () -> Unit,
|
||||
): GatewaySession {
|
||||
val ws = client.webSocketSession("wss://gateway.discord.gg/") {
|
||||
userAgent(ApiConstants.gatewayUserAgent)
|
||||
url {
|
||||
parameter("encoding", "json")
|
||||
|
|
@ -62,10 +59,14 @@ class Gateway(
|
|||
// parameter("compress", "zstd-stream")
|
||||
}
|
||||
}
|
||||
this.ws = ws
|
||||
|
||||
return GatewaySession(ws, token, eventHandlers, scope, onDestroy)
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
ws.incoming.consumeAsFlow()
|
||||
.onCompletion { cleanup(it) }
|
||||
.onCompletion { close(it) }
|
||||
.onEach { frame ->
|
||||
if (frame !is Frame.Text)
|
||||
// if (frame !is Frame.Text && frame !is Frame.Binary)
|
||||
|
|
@ -74,9 +75,9 @@ class Gateway(
|
|||
logger.d { "Received payload ${frame.readText()}" }
|
||||
|
||||
val raw = json.decodeFromString<Payload.Unknown>(frame.readText())
|
||||
val seq = this.seq ?: 0
|
||||
val seq = this.lastSeq ?: 0
|
||||
if (seq + 1 == raw.s) {
|
||||
this.seq = raw.s
|
||||
this.lastSeq = raw.s
|
||||
} else if (raw.s != null) {
|
||||
resume(ResumeReason.SkippedSequence(raw.s))
|
||||
return@onEach
|
||||
|
|
@ -90,63 +91,37 @@ class Gateway(
|
|||
.launchIn(scope)
|
||||
}
|
||||
|
||||
suspend fun handlePayload(e: Payload.Incoming<*>) {
|
||||
logger.d { e.toString() }
|
||||
when (val event = e.d) {
|
||||
private suspend fun handlePayload(payload: Payload.Incoming<*>) {
|
||||
logger.d { payload.toString() }
|
||||
when (val event = payload.d) {
|
||||
is Event.Hello -> handleHello(event)
|
||||
is Event.Ready -> handlers.ready.handle(event)
|
||||
is Event.Heartbeat -> {}
|
||||
is Event.HeartbeatAck -> { missedBeats -= 1 }
|
||||
is Event.HeartbeatAck -> { missedHeartbeats -= 1 }
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun handleUnknownPayload(e: Payload.Unknown) {
|
||||
logger.w { "Unknown payload $e" }
|
||||
private suspend fun handleUnknownPayload(payload: Payload.Unknown) {
|
||||
logger.w { "Unknown payload $payload" }
|
||||
}
|
||||
|
||||
suspend fun handleHello(e: Event.Hello) {
|
||||
val token = auth.token
|
||||
?: throw IllegalStateException("Token missing between connection and hello, cannot send Identify")
|
||||
|
||||
private suspend fun handleHello(e: Event.Hello) {
|
||||
Event.Identify(token = token).pack().send()
|
||||
|
||||
val interval = e.heartbeatInterval.milliseconds
|
||||
scope.launch {
|
||||
startHeartbeat(interval)
|
||||
}
|
||||
}
|
||||
|
||||
private var missedBeats = 0
|
||||
private suspend fun startHeartbeat(interval: Duration) {
|
||||
val ws = this.ws
|
||||
?: throw IllegalStateException("Ws missing whilst starting heartbeat")
|
||||
|
||||
delay(interval * Random.nextDouble())
|
||||
while (this@Gateway.ws == ws) {
|
||||
if (missedBeats >= 1) {
|
||||
while (true) {
|
||||
if (missedHeartbeats >= 1) {
|
||||
resume(ResumeReason.MissedHeartbeat)
|
||||
break
|
||||
}
|
||||
Event.QoSHeartbeat(this@Gateway.seq).pack().send()
|
||||
missedBeats += 1
|
||||
Event.QoSHeartbeat(lastSeq).pack().send()
|
||||
missedHeartbeats += 1
|
||||
delay(interval)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle resuming, etc..
|
||||
suspend fun cleanup(error: Throwable? = null) {
|
||||
logger.d(error) { "Websocket connection closed, cleaning up..." }
|
||||
this.ws = null
|
||||
}
|
||||
|
||||
suspend fun disconnect() {
|
||||
val ws = ws
|
||||
if (ws == null) {
|
||||
logger.w(Throwable()) { "Attempted to disconnect, but client was not connected" }
|
||||
return
|
||||
}
|
||||
this.ws = null
|
||||
ws.close()
|
||||
}
|
||||
|
||||
private sealed class ResumeReason {
|
||||
|
|
@ -160,7 +135,7 @@ class Gateway(
|
|||
is ResumeReason.MissedHeartbeat ->
|
||||
"heartbeat missed"
|
||||
is ResumeReason.SkippedSequence ->
|
||||
"payloads skipped one sequence (expected: $seq, actual: ${reason.next})"
|
||||
"payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})"
|
||||
is ResumeReason.CloseCode ->
|
||||
"closed with code ${reason.code}"
|
||||
null ->
|
||||
|
|
@ -171,9 +146,15 @@ class Gateway(
|
|||
// TODO
|
||||
}
|
||||
|
||||
// TODO: handle resuming, etc..
|
||||
suspend fun close(error: Throwable? = null) {
|
||||
logger.d(error) { "Websocket connection closed, cleaning up..." }
|
||||
ws.close()
|
||||
if (scope.isActive) scope.cancel()
|
||||
onDestroy()
|
||||
}
|
||||
|
||||
private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() {
|
||||
val ws = ws
|
||||
?: throw IllegalStateException("Tried to send with no connection")
|
||||
logger.d { "Sending payload $this" }
|
||||
logger.d { "Raw: ${json.encodeToString(this)}" }
|
||||
ws.send(json.encodeToString(this))
|
||||
|
|
@ -66,7 +66,7 @@ sealed interface Event {
|
|||
// 11
|
||||
@JvmInline
|
||||
@Serializable
|
||||
value class HeartbeatAck(private val nothing: Nothing?) : Incoming, Outgoing
|
||||
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
|
||||
|
||||
// 10
|
||||
@Serializable
|
||||
|
|
|
|||
|
|
@ -14,5 +14,5 @@ interface AppGraph {
|
|||
val auth: AuthRepository
|
||||
val users: UserRepository
|
||||
|
||||
val gatewayHandlers: GatewayHandlerGraph
|
||||
val gatewayHandlers: EventHandlerGraph
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,6 @@ import moe.lava.neon.core.api.gateway.handlers.ReadyHandler
|
|||
|
||||
@GraphExtension
|
||||
@ContributesTo(AppScope::class)
|
||||
interface GatewayHandlerGraph {
|
||||
interface EventHandlerGraph {
|
||||
val ready: ReadyHandler
|
||||
}
|
||||
|
|
@ -25,7 +25,7 @@ import dev.zacsweers.metro.Inject
|
|||
import dev.zacsweers.metrox.viewmodel.ViewModelKey
|
||||
import dev.zacsweers.metrox.viewmodel.metroViewModel
|
||||
import kotlinx.coroutines.launch
|
||||
import moe.lava.neon.core.api.gateway.Gateway
|
||||
import moe.lava.neon.core.api.gateway.GatewayHandler
|
||||
import moe.lava.neon.core.repository.AuthRepository
|
||||
import moe.lava.neon.resources.Res
|
||||
import moe.lava.neon.resources.compose_multiplatform
|
||||
|
|
@ -81,7 +81,7 @@ fun Sample(onRequestLogout: () -> Unit) {
|
|||
@ContributesIntoMap(AppScope::class)
|
||||
class SampleViewModel(
|
||||
private val auth: AuthRepository,
|
||||
private val gateway: Gateway,
|
||||
private val gateway: GatewayHandler,
|
||||
) : ViewModel() {
|
||||
val token get() = auth.token
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue