refactor: split Gateway into Handler and Session

This commit is contained in:
Cilly Leang 2026-01-27 17:37:25 +11:00
parent 214efd7417
commit b04d75df99
Signed by: cilly
GPG key ID: 6500251E087653C9
6 changed files with 108 additions and 89 deletions

View file

@ -0,0 +1,38 @@
package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.di.EventHandlerGraph
import moe.lava.neon.core.repository.AuthRepository
@Inject
class GatewayHandler(
private val auth: AuthRepository,
private val handlers: EventHandlerGraph,
) {
private val logger = Logger.withTag("neon.core.api.gateway/handler")
private var session: GatewaySession? = null
@OptIn(ExperimentalSerializationApi::class)
suspend fun connect() {
if (session != null) {
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
return
}
val token = auth.token
?: throw IllegalStateException("Tried to connect to gateway with no token")
session = GatewaySession.start(
token = token,
eventHandlers = handlers,
onDestroy = { session = null }
)
}
suspend fun disconnect() {
val session = session
?: throw IllegalStateException("Tried disconnecting with no session")
session.close()
}
}

View file

@ -1,7 +1,6 @@
package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import io.ktor.client.HttpClient
import io.ktor.client.plugins.cookies.HttpCookies
import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
@ -15,57 +14,59 @@ import io.ktor.websocket.readText
import io.ktor.websocket.send
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.api.ApiConstants
import moe.lava.neon.core.di.GatewayHandlerGraph
import moe.lava.neon.core.repository.AuthRepository
import moe.lava.neon.core.api.ApiConstants.json
import moe.lava.neon.core.di.EventHandlerGraph
import kotlin.random.Random
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
@Inject
class Gateway(
private val auth: AuthRepository,
private val handlers: GatewayHandlerGraph,
private val logger = Logger.withTag("neon.core.api.gateway/session")
class GatewaySession private constructor(
private var ws: DefaultClientWebSocketSession,
private val token: String,
private val handlers: EventHandlerGraph,
private val scope: CoroutineScope,
private val onDestroy: () -> Unit,
) {
private val logger = Logger.withTag("neon.core.api/gateway")
private val scope = CoroutineScope(Dispatchers.IO)
private var ws: DefaultClientWebSocketSession? = null
private var lastSeq: Int? = null
private var missedHeartbeats = 0
private val json = ApiConstants.json
private var seq: Int? = null
@OptIn(ExperimentalSerializationApi::class)
suspend fun connect() {
if (ws != null) {
logger.w(Throwable()) { "Attempted to connect, but client already connected, ignoring..." }
return
}
if (auth.token == null) {
throw IllegalStateException("Tried to connect to gateway with no token")
}
val ws = HttpClient {
install(HttpCookies)
install(WebSockets)
}.webSocketSession("wss://gateway.discord.gg/") {
userAgent(ApiConstants.gatewayUserAgent)
url {
parameter("encoding", "json")
parameter("v", "9")
// parameter("compress", "zstd-stream")
companion object {
suspend fun start(
token: String,
eventHandlers: EventHandlerGraph,
client: HttpClient = HttpClient {
install(HttpCookies)
install(WebSockets)
},
scope: CoroutineScope = CoroutineScope(Dispatchers.IO),
onDestroy: () -> Unit,
): GatewaySession {
val ws = client.webSocketSession("wss://gateway.discord.gg/") {
userAgent(ApiConstants.gatewayUserAgent)
url {
parameter("encoding", "json")
parameter("v", "9")
// parameter("compress", "zstd-stream")
}
}
}
this.ws = ws
return GatewaySession(ws, token, eventHandlers, scope, onDestroy)
}
}
init {
ws.incoming.consumeAsFlow()
.onCompletion { cleanup(it) }
.onCompletion { close(it) }
.onEach { frame ->
if (frame !is Frame.Text)
// if (frame !is Frame.Text && frame !is Frame.Binary)
@ -74,9 +75,9 @@ class Gateway(
logger.d { "Received payload ${frame.readText()}" }
val raw = json.decodeFromString<Payload.Unknown>(frame.readText())
val seq = this.seq ?: 0
val seq = this.lastSeq ?: 0
if (seq + 1 == raw.s) {
this.seq = raw.s
this.lastSeq = raw.s
} else if (raw.s != null) {
resume(ResumeReason.SkippedSequence(raw.s))
return@onEach
@ -90,65 +91,39 @@ class Gateway(
.launchIn(scope)
}
suspend fun handlePayload(e: Payload.Incoming<*>) {
logger.d { e.toString() }
when (val event = e.d) {
private suspend fun handlePayload(payload: Payload.Incoming<*>) {
logger.d { payload.toString() }
when (val event = payload.d) {
is Event.Hello -> handleHello(event)
is Event.Ready -> handlers.ready.handle(event)
is Event.Heartbeat -> {}
is Event.HeartbeatAck -> { missedBeats -= 1 }
is Event.HeartbeatAck -> { missedHeartbeats -= 1 }
}
}
suspend fun handleUnknownPayload(e: Payload.Unknown) {
logger.w { "Unknown payload $e" }
private suspend fun handleUnknownPayload(payload: Payload.Unknown) {
logger.w { "Unknown payload $payload" }
}
suspend fun handleHello(e: Event.Hello) {
val token = auth.token
?: throw IllegalStateException("Token missing between connection and hello, cannot send Identify")
private suspend fun handleHello(e: Event.Hello) {
Event.Identify(token = token).pack().send()
val interval = e.heartbeatInterval.milliseconds
scope.launch {
startHeartbeat(interval)
}
}
private var missedBeats = 0
private suspend fun startHeartbeat(interval: Duration) {
val ws = this.ws
?: throw IllegalStateException("Ws missing whilst starting heartbeat")
delay(interval * Random.nextDouble())
while (this@Gateway.ws == ws) {
if (missedBeats >= 1) {
resume(ResumeReason.MissedHeartbeat)
delay(interval * Random.nextDouble())
while (true) {
if (missedHeartbeats >= 1) {
resume(ResumeReason.MissedHeartbeat)
break
}
Event.QoSHeartbeat(lastSeq).pack().send()
missedHeartbeats += 1
delay(interval)
break
}
Event.QoSHeartbeat(this@Gateway.seq).pack().send()
missedBeats += 1
delay(interval)
}
}
// TODO: handle resuming, etc..
suspend fun cleanup(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." }
this.ws = null
}
suspend fun disconnect() {
val ws = ws
if (ws == null) {
logger.w(Throwable()) { "Attempted to disconnect, but client was not connected" }
return
}
this.ws = null
ws.close()
}
private sealed class ResumeReason {
data object MissedHeartbeat : ResumeReason()
data class SkippedSequence(val next: Int) : ResumeReason()
@ -160,7 +135,7 @@ class Gateway(
is ResumeReason.MissedHeartbeat ->
"heartbeat missed"
is ResumeReason.SkippedSequence ->
"payloads skipped one sequence (expected: $seq, actual: ${reason.next})"
"payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})"
is ResumeReason.CloseCode ->
"closed with code ${reason.code}"
null ->
@ -171,9 +146,15 @@ class Gateway(
// TODO
}
// TODO: handle resuming, etc..
suspend fun close(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." }
ws.close()
if (scope.isActive) scope.cancel()
onDestroy()
}
private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() {
val ws = ws
?: throw IllegalStateException("Tried to send with no connection")
logger.d { "Sending payload $this" }
logger.d { "Raw: ${json.encodeToString(this)}" }
ws.send(json.encodeToString(this))

View file

@ -66,7 +66,7 @@ sealed interface Event {
// 11
@JvmInline
@Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming, Outgoing
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 10
@Serializable

View file

@ -14,5 +14,5 @@ interface AppGraph {
val auth: AuthRepository
val users: UserRepository
val gatewayHandlers: GatewayHandlerGraph
val gatewayHandlers: EventHandlerGraph
}

View file

@ -7,6 +7,6 @@ import moe.lava.neon.core.api.gateway.handlers.ReadyHandler
@GraphExtension
@ContributesTo(AppScope::class)
interface GatewayHandlerGraph {
interface EventHandlerGraph {
val ready: ReadyHandler
}

View file

@ -25,7 +25,7 @@ import dev.zacsweers.metro.Inject
import dev.zacsweers.metrox.viewmodel.ViewModelKey
import dev.zacsweers.metrox.viewmodel.metroViewModel
import kotlinx.coroutines.launch
import moe.lava.neon.core.api.gateway.Gateway
import moe.lava.neon.core.api.gateway.GatewayHandler
import moe.lava.neon.core.repository.AuthRepository
import moe.lava.neon.resources.Res
import moe.lava.neon.resources.compose_multiplatform
@ -81,7 +81,7 @@ fun Sample(onRequestLogout: () -> Unit) {
@ContributesIntoMap(AppScope::class)
class SampleViewModel(
private val auth: AuthRepository,
private val gateway: Gateway,
private val gateway: GatewayHandler,
) : ViewModel() {
val token get() = auth.token