refactor: split Gateway into Handler and Session

This commit is contained in:
Cilly Leang 2026-01-27 17:37:25 +11:00
parent 214efd7417
commit b04d75df99
Signed by: cilly
GPG key ID: 6500251E087653C9
6 changed files with 108 additions and 89 deletions

View file

@ -0,0 +1,38 @@
package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.di.EventHandlerGraph
import moe.lava.neon.core.repository.AuthRepository
@Inject
class GatewayHandler(
private val auth: AuthRepository,
private val handlers: EventHandlerGraph,
) {
private val logger = Logger.withTag("neon.core.api.gateway/handler")
private var session: GatewaySession? = null
@OptIn(ExperimentalSerializationApi::class)
suspend fun connect() {
if (session != null) {
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
return
}
val token = auth.token
?: throw IllegalStateException("Tried to connect to gateway with no token")
session = GatewaySession.start(
token = token,
eventHandlers = handlers,
onDestroy = { session = null }
)
}
suspend fun disconnect() {
val session = session
?: throw IllegalStateException("Tried disconnecting with no session")
session.close()
}
}

View file

@ -1,7 +1,6 @@
package moe.lava.neon.core.api.gateway package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.plugins.cookies.HttpCookies import io.ktor.client.plugins.cookies.HttpCookies
import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
@ -15,46 +14,44 @@ import io.ktor.websocket.readText
import io.ktor.websocket.send import io.ktor.websocket.send
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.api.ApiConstants import moe.lava.neon.core.api.ApiConstants
import moe.lava.neon.core.di.GatewayHandlerGraph import moe.lava.neon.core.api.ApiConstants.json
import moe.lava.neon.core.repository.AuthRepository import moe.lava.neon.core.di.EventHandlerGraph
import kotlin.random.Random import kotlin.random.Random
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.milliseconds
@Inject private val logger = Logger.withTag("neon.core.api.gateway/session")
class Gateway(
private val auth: AuthRepository, class GatewaySession private constructor(
private val handlers: GatewayHandlerGraph, private var ws: DefaultClientWebSocketSession,
private val token: String,
private val handlers: EventHandlerGraph,
private val scope: CoroutineScope,
private val onDestroy: () -> Unit,
) { ) {
private val logger = Logger.withTag("neon.core.api/gateway") private var lastSeq: Int? = null
private val scope = CoroutineScope(Dispatchers.IO) private var missedHeartbeats = 0
private var ws: DefaultClientWebSocketSession? = null
private val json = ApiConstants.json companion object {
suspend fun start(
private var seq: Int? = null token: String,
eventHandlers: EventHandlerGraph,
@OptIn(ExperimentalSerializationApi::class) client: HttpClient = HttpClient {
suspend fun connect() {
if (ws != null) {
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
return
}
if (auth.token == null) {
throw IllegalStateException("Tried to connect to gateway with no token")
}
val ws = HttpClient {
install(HttpCookies) install(HttpCookies)
install(WebSockets) install(WebSockets)
}.webSocketSession("wss://gateway.discord.gg/") { },
scope: CoroutineScope = CoroutineScope(Dispatchers.IO),
onDestroy: () -> Unit,
): GatewaySession {
val ws = client.webSocketSession("wss://gateway.discord.gg/") {
userAgent(ApiConstants.gatewayUserAgent) userAgent(ApiConstants.gatewayUserAgent)
url { url {
parameter("encoding", "json") parameter("encoding", "json")
@ -62,10 +59,14 @@ class Gateway(
// parameter("compress", "zstd-stream") // parameter("compress", "zstd-stream")
} }
} }
this.ws = ws
return GatewaySession(ws, token, eventHandlers, scope, onDestroy)
}
}
init {
ws.incoming.consumeAsFlow() ws.incoming.consumeAsFlow()
.onCompletion { cleanup(it) } .onCompletion { close(it) }
.onEach { frame -> .onEach { frame ->
if (frame !is Frame.Text) if (frame !is Frame.Text)
// if (frame !is Frame.Text && frame !is Frame.Binary) // if (frame !is Frame.Text && frame !is Frame.Binary)
@ -74,9 +75,9 @@ class Gateway(
logger.d { "Received payload ${frame.readText()}" } logger.d { "Received payload ${frame.readText()}" }
val raw = json.decodeFromString<Payload.Unknown>(frame.readText()) val raw = json.decodeFromString<Payload.Unknown>(frame.readText())
val seq = this.seq ?: 0 val seq = this.lastSeq ?: 0
if (seq + 1 == raw.s) { if (seq + 1 == raw.s) {
this.seq = raw.s this.lastSeq = raw.s
} else if (raw.s != null) { } else if (raw.s != null) {
resume(ResumeReason.SkippedSequence(raw.s)) resume(ResumeReason.SkippedSequence(raw.s))
return@onEach return@onEach
@ -90,63 +91,37 @@ class Gateway(
.launchIn(scope) .launchIn(scope)
} }
suspend fun handlePayload(e: Payload.Incoming<*>) { private suspend fun handlePayload(payload: Payload.Incoming<*>) {
logger.d { e.toString() } logger.d { payload.toString() }
when (val event = e.d) { when (val event = payload.d) {
is Event.Hello -> handleHello(event) is Event.Hello -> handleHello(event)
is Event.Ready -> handlers.ready.handle(event) is Event.Ready -> handlers.ready.handle(event)
is Event.Heartbeat -> {} is Event.Heartbeat -> {}
is Event.HeartbeatAck -> { missedBeats -= 1 } is Event.HeartbeatAck -> { missedHeartbeats -= 1 }
} }
} }
suspend fun handleUnknownPayload(e: Payload.Unknown) { private suspend fun handleUnknownPayload(payload: Payload.Unknown) {
logger.w { "Unknown payload $e" } logger.w { "Unknown payload $payload" }
} }
suspend fun handleHello(e: Event.Hello) { private suspend fun handleHello(e: Event.Hello) {
val token = auth.token
?: throw IllegalStateException("Token missing between connection and hello, cannot send Identify")
Event.Identify(token = token).pack().send() Event.Identify(token = token).pack().send()
val interval = e.heartbeatInterval.milliseconds val interval = e.heartbeatInterval.milliseconds
scope.launch { scope.launch {
startHeartbeat(interval)
}
}
private var missedBeats = 0
private suspend fun startHeartbeat(interval: Duration) {
val ws = this.ws
?: throw IllegalStateException("Ws missing whilst starting heartbeat")
delay(interval * Random.nextDouble()) delay(interval * Random.nextDouble())
while (this@Gateway.ws == ws) { while (true) {
if (missedBeats >= 1) { if (missedHeartbeats >= 1) {
resume(ResumeReason.MissedHeartbeat) resume(ResumeReason.MissedHeartbeat)
break break
} }
Event.QoSHeartbeat(this@Gateway.seq).pack().send() Event.QoSHeartbeat(lastSeq).pack().send()
missedBeats += 1 missedHeartbeats += 1
delay(interval) delay(interval)
break
} }
} }
// TODO: handle resuming, etc..
suspend fun cleanup(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." }
this.ws = null
}
suspend fun disconnect() {
val ws = ws
if (ws == null) {
logger.w(Throwable()) { "Attempted to disconnect, but client was not connected" }
return
}
this.ws = null
ws.close()
} }
private sealed class ResumeReason { private sealed class ResumeReason {
@ -160,7 +135,7 @@ class Gateway(
is ResumeReason.MissedHeartbeat -> is ResumeReason.MissedHeartbeat ->
"heartbeat missed" "heartbeat missed"
is ResumeReason.SkippedSequence -> is ResumeReason.SkippedSequence ->
"payloads skipped one sequence (expected: $seq, actual: ${reason.next})" "payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})"
is ResumeReason.CloseCode -> is ResumeReason.CloseCode ->
"closed with code ${reason.code}" "closed with code ${reason.code}"
null -> null ->
@ -171,9 +146,15 @@ class Gateway(
// TODO // TODO
} }
// TODO: handle resuming, etc..
suspend fun close(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." }
ws.close()
if (scope.isActive) scope.cancel()
onDestroy()
}
private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() { private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() {
val ws = ws
?: throw IllegalStateException("Tried to send with no connection")
logger.d { "Sending payload $this" } logger.d { "Sending payload $this" }
logger.d { "Raw: ${json.encodeToString(this)}" } logger.d { "Raw: ${json.encodeToString(this)}" }
ws.send(json.encodeToString(this)) ws.send(json.encodeToString(this))

View file

@ -66,7 +66,7 @@ sealed interface Event {
// 11 // 11
@JvmInline @JvmInline
@Serializable @Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming, Outgoing value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 10 // 10
@Serializable @Serializable

View file

@ -14,5 +14,5 @@ interface AppGraph {
val auth: AuthRepository val auth: AuthRepository
val users: UserRepository val users: UserRepository
val gatewayHandlers: GatewayHandlerGraph val gatewayHandlers: EventHandlerGraph
} }

View file

@ -7,6 +7,6 @@ import moe.lava.neon.core.api.gateway.handlers.ReadyHandler
@GraphExtension @GraphExtension
@ContributesTo(AppScope::class) @ContributesTo(AppScope::class)
interface GatewayHandlerGraph { interface EventHandlerGraph {
val ready: ReadyHandler val ready: ReadyHandler
} }

View file

@ -25,7 +25,7 @@ import dev.zacsweers.metro.Inject
import dev.zacsweers.metrox.viewmodel.ViewModelKey import dev.zacsweers.metrox.viewmodel.ViewModelKey
import dev.zacsweers.metrox.viewmodel.metroViewModel import dev.zacsweers.metrox.viewmodel.metroViewModel
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import moe.lava.neon.core.api.gateway.Gateway import moe.lava.neon.core.api.gateway.GatewayHandler
import moe.lava.neon.core.repository.AuthRepository import moe.lava.neon.core.repository.AuthRepository
import moe.lava.neon.resources.Res import moe.lava.neon.resources.Res
import moe.lava.neon.resources.compose_multiplatform import moe.lava.neon.resources.compose_multiplatform
@ -81,7 +81,7 @@ fun Sample(onRequestLogout: () -> Unit) {
@ContributesIntoMap(AppScope::class) @ContributesIntoMap(AppScope::class)
class SampleViewModel( class SampleViewModel(
private val auth: AuthRepository, private val auth: AuthRepository,
private val gateway: Gateway, private val gateway: GatewayHandler,
) : ViewModel() { ) : ViewModel() {
val token get() = auth.token val token get() = auth.token