feat(gateway): handle resume and opcode 9

This commit is contained in:
Cilly Leang 2026-01-27 21:00:15 +11:00
parent bed582b953
commit 22faef0fb0
Signed by: cilly
GPG key ID: 6500251E087653C9
9 changed files with 187 additions and 62 deletions

View file

@ -0,0 +1,19 @@
package moe.lava.neon.core.api.gateway
import io.ktor.websocket.CloseReason
sealed interface GatewayCloseReason {
sealed interface ClientInitiated : GatewayCloseReason
sealed class ShouldReconnect(val resume: Boolean) : GatewayCloseReason
sealed class KeepDisconnected : GatewayCloseReason
data object MissedHeartbeat : ShouldReconnect(resume = true), ClientInitiated
data class SkippedSequence(val next: Int) : ShouldReconnect(resume = true), ClientInitiated
data class InvalidSession(val resumable: Boolean) : ShouldReconnect(resume = resumable), ClientInitiated
// TODO: handle non-resumable cases properly
data class ServerClosed(val closeCode: CloseReason) : ShouldReconnect(resume = true)
data object ClientPaused : KeepDisconnected(), ClientInitiated
data object Unknown : ShouldReconnect(resume = true)
}

View file

@ -2,9 +2,15 @@ package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject import dev.zacsweers.metro.Inject
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.di.EventHandlerGraph import moe.lava.neon.core.di.EventHandlerGraph
import moe.lava.neon.core.repository.AuthRepository import moe.lava.neon.core.repository.AuthRepository
import kotlin.math.pow
import kotlin.time.Duration.Companion.seconds
@Inject @Inject
class GatewayHandler( class GatewayHandler(
@ -12,7 +18,11 @@ class GatewayHandler(
private val handlers: EventHandlerGraph, private val handlers: EventHandlerGraph,
) { ) {
private val logger = Logger.withTag("neon.core.api.gateway/handler") private val logger = Logger.withTag("neon.core.api.gateway/handler")
private val scope = CoroutineScope(Dispatchers.IO)
private var session: GatewaySession? = null private var session: GatewaySession? = null
private var resumeProps: ResumeProperties? = null
private var retryAttempts: Int = 0
@OptIn(ExperimentalSerializationApi::class) @OptIn(ExperimentalSerializationApi::class)
suspend fun connect() { suspend fun connect() {
@ -26,13 +36,45 @@ class GatewayHandler(
session = GatewaySession.start( session = GatewaySession.start(
token = token, token = token,
eventHandlers = handlers, eventHandlers = handlers,
onDestroy = { session = null } resumeProps = resumeProps,
onSuccess = {
logger.d { "Successful session start" }
retryAttempts = 0
},
onDestroy = { reason, resumeProps ->
session = null
if (reason is GatewayCloseReason.KeepDisconnected) {
this.resumeProps = resumeProps
}
if (reason is GatewayCloseReason.ShouldReconnect) {
if (reason.resume) {
this.resumeProps = resumeProps
} else {
this.resumeProps = null
}
scope.launch {
var res: Result<Unit>
do {
val dur = 2.0.pow(retryAttempts).seconds
logger.d { "Reconnecting in ${dur.inWholeMilliseconds}ms" }
delay(dur)
retryAttempts += 1
res = runCatching { connect() }
res.exceptionOrNull()?.let {
logger.e(it) { "Reconnect failed" }
}
} while(res.isFailure)
}
}
}
) )
} }
suspend fun disconnect() { suspend fun disconnect() {
val session = session val session = session
?: throw IllegalStateException("Tried disconnecting with no session") ?: throw IllegalStateException("Tried disconnecting with no session")
session.close() session.close(GatewayCloseReason.ClientPaused)
} }
} }

View file

@ -9,11 +9,11 @@ import io.ktor.client.plugins.websocket.webSocketSession
import io.ktor.client.request.parameter import io.ktor.client.request.parameter
import io.ktor.http.userAgent import io.ktor.http.userAgent
import io.ktor.websocket.Frame import io.ktor.websocket.Frame
import io.ktor.websocket.close
import io.ktor.websocket.readText import io.ktor.websocket.readText
import io.ktor.websocket.send import io.ktor.websocket.send
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.flow.consumeAsFlow
@ -35,10 +35,13 @@ class GatewaySession private constructor(
private val token: String, private val token: String,
private val handlers: EventHandlerGraph, private val handlers: EventHandlerGraph,
private val scope: CoroutineScope, private val scope: CoroutineScope,
private val onDestroy: () -> Unit, private var resumeProps: ResumeProperties?,
private val onDestroy: (GatewayCloseReason, ResumeProperties?) -> Unit,
private val onSuccess: () -> Unit,
) { ) {
private var lastSeq: Int? = null private var lastSeq: Int? = resumeProps?.lastSequence
private var missedHeartbeats = 0 private var missedHeartbeats = 0
private var closeReason: GatewayCloseReason? = null
companion object { companion object {
suspend fun start( suspend fun start(
@ -49,9 +52,13 @@ class GatewaySession private constructor(
install(WebSockets) install(WebSockets)
}, },
scope: CoroutineScope = CoroutineScope(Dispatchers.IO), scope: CoroutineScope = CoroutineScope(Dispatchers.IO),
onDestroy: () -> Unit, resumeProps: ResumeProperties? = null,
onDestroy: (GatewayCloseReason, ResumeProperties?) -> Unit,
onSuccess: () -> Unit,
): GatewaySession { ): GatewaySession {
val ws = client.webSocketSession("wss://gateway.discord.gg/") { val ws = client.webSocketSession(
resumeProps?.resumeGatewayUrl ?: "wss://gateway.discord.gg/"
) {
userAgent(ApiConstants.gatewayUserAgent) userAgent(ApiConstants.gatewayUserAgent)
url { url {
parameter("encoding", "json") parameter("encoding", "json")
@ -60,13 +67,13 @@ class GatewaySession private constructor(
} }
} }
return GatewaySession(ws, token, eventHandlers, scope, onDestroy) return GatewaySession(ws, token, eventHandlers, scope, resumeProps, onDestroy, onSuccess)
} }
} }
init { init {
ws.incoming.consumeAsFlow() ws.incoming.consumeAsFlow()
.onCompletion { close(it) } .onCompletion { onClose(it) }
.onEach { frame -> .onEach { frame ->
if (frame !is Frame.Text) if (frame !is Frame.Text)
// if (frame !is Frame.Text && frame !is Frame.Binary) // if (frame !is Frame.Text && frame !is Frame.Binary)
@ -79,7 +86,7 @@ class GatewaySession private constructor(
if (seq + 1 == raw.s) { if (seq + 1 == raw.s) {
this.lastSeq = raw.s this.lastSeq = raw.s
} else if (raw.s != null) { } else if (raw.s != null) {
resume(ResumeReason.SkippedSequence(raw.s)) close(GatewayCloseReason.SkippedSequence(raw.s))
return@onEach return@onEach
} }
@ -95,8 +102,13 @@ class GatewaySession private constructor(
logger.d { payload.toString() } logger.d { payload.toString() }
when (val event = payload.d) { when (val event = payload.d) {
is Event.Hello -> handleHello(event) is Event.Hello -> handleHello(event)
is Event.Ready -> handlers.ready.handle(event) is Event.Ready -> handlers.ready.handle(event) {
resumeProps = it
onSuccess()
}
is Event.Resumed -> onSuccess()
is Event.Heartbeat -> handleHeartbeat() is Event.Heartbeat -> handleHeartbeat()
is Event.InvalidSession -> close(GatewayCloseReason.InvalidSession(event.resumable))
is Event.HeartbeatAck -> { missedHeartbeats -= 1 } is Event.HeartbeatAck -> { missedHeartbeats -= 1 }
} }
} }
@ -112,14 +124,23 @@ class GatewaySession private constructor(
} }
private suspend fun handleHello(e: Event.Hello) { private suspend fun handleHello(e: Event.Hello) {
Event.Identify(token = token).pack().send() val resumeProps = resumeProps
if (resumeProps != null) {
Event.Resume(
token = token,
sessionId = resumeProps.sessionId,
seq = resumeProps.lastSequence
).pack().send()
} else {
Event.Identify(token = token).pack().send()
}
val interval = e.heartbeatInterval.milliseconds val interval = e.heartbeatInterval.milliseconds
scope.launch { scope.launch {
delay(interval * Random.nextDouble()) delay(interval * Random.nextDouble())
while (true) { while (true) {
if (missedHeartbeats >= 1) { if (missedHeartbeats >= 1) {
resume(ResumeReason.MissedHeartbeat) close(GatewayCloseReason.MissedHeartbeat)
break break
} }
Event.QoSHeartbeat(lastSeq).pack().send() Event.QoSHeartbeat(lastSeq).pack().send()
@ -129,34 +150,40 @@ class GatewaySession private constructor(
} }
} }
private sealed class ResumeReason { fun close(reason: GatewayCloseReason.ClientInitiated?) {
data object MissedHeartbeat : ResumeReason()
data class SkippedSequence(val next: Int) : ResumeReason()
data class CloseCode(val code: Int) : ResumeReason()
}
private suspend fun resume(reason: ResumeReason?) {
val msg = when (reason) { val msg = when (reason) {
is ResumeReason.MissedHeartbeat -> is GatewayCloseReason.MissedHeartbeat ->
"heartbeat missed" "heartbeat missed"
is ResumeReason.SkippedSequence -> is GatewayCloseReason.SkippedSequence ->
"payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})" "payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})"
is ResumeReason.CloseCode -> is GatewayCloseReason.InvalidSession ->
"closed with code ${reason.code}" "invalid session (resumable: $reason)"
is GatewayCloseReason.ClientPaused ->
"client requested pause"
null -> null ->
"no reason" "no reason"
} }
closeReason = reason
logger.e { "Resuming, cause: $msg" } logger.e { "Client-initiated close, cause: $msg" }
// TODO ws.cancel()
} }
// TODO: handle resuming, etc.. @OptIn(ExperimentalCoroutinesApi::class)
suspend fun close(error: Throwable? = null) { private fun onClose(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." } logger.d(error) { "Websocket connection closed, cleaning up..." }
ws.close()
if (scope.isActive) scope.cancel() if (scope.isActive) scope.cancel()
onDestroy() if (resumeProps == null) {
logger.w { "No resume props stored" }
}
onDestroy(
closeReason
?: runCatching { ws.closeReason.getCompleted() }
.getOrNull()
?.let { GatewayCloseReason.ServerClosed(it) }
?: GatewayCloseReason.Unknown,
resumeProps?.copy(lastSequence = lastSeq ?: 0)
)
} }
private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() { private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() {

View file

@ -1,6 +1,5 @@
package moe.lava.neon.core.api.gateway package moe.lava.neon.core.api.gateway
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonElement
import moe.lava.neon.core.api.ApiConstants import moe.lava.neon.core.api.ApiConstants
@ -49,32 +48,8 @@ sealed interface Event {
@Serializable @Serializable
value class Heartbeat(val lastSequence: Int?) : Incoming, Outgoing value class Heartbeat(val lastSequence: Int?) : Incoming, Outgoing
// 40
@Serializable
data class QoSHeartbeat(
val seq: Int?,
val qos: QoSPayload = QoSPayload(),
) : Outgoing {
@Serializable
data class QoSPayload(
val ver: Int = 27,
val active: Boolean = true,
val reasons: List<String> = listOf("foregrounded"),
)
}
// 11
@JvmInline
@Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 10
@Serializable
data class Hello(val heartbeatInterval: Int) : Incoming
// 2 // 2
@Serializable @Serializable
@OptIn(ExperimentalSerializationApi::class)
data class Identify( data class Identify(
val token: String, val token: String,
val properties: ApiConstants.GatewayProperties = ApiConstants.GatewayProperties(), val properties: ApiConstants.GatewayProperties = ApiConstants.GatewayProperties(),
@ -89,6 +64,42 @@ sealed interface Event {
// val clientState: ClientState, // val clientState: ClientState,
) : Outgoing ) : Outgoing
// 6
@Serializable
data class Resume(
val token: String,
val sessionId: String,
val seq: Int,
) : Outgoing
// 9
@JvmInline
@Serializable
value class InvalidSession(val resumable: Boolean) : Incoming
// 10
@Serializable
data class Hello(val heartbeatInterval: Int) : Incoming
// 11
@JvmInline
@Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 40
@Serializable
data class QoSHeartbeat(
val seq: Int?,
val qos: QoSPayload = QoSPayload(),
) : Outgoing {
@Serializable
data class QoSPayload(
val ver: Int = 27,
val active: Boolean = true,
val reasons: List<String> = listOf("foregrounded"),
)
}
@Serializable @Serializable
data class Ready( data class Ready(
val v: Int, val v: Int,
@ -98,4 +109,7 @@ sealed interface Event {
val resumeGatewayUrl: String, val resumeGatewayUrl: String,
// val application: Application, // val application: Application,
) : Dispatch() ) : Dispatch()
@Serializable
data object Resumed : Dispatch()
} }

View file

@ -0,0 +1,7 @@
package moe.lava.neon.core.api.gateway
data class ResumeProperties(
val sessionId: String,
val resumeGatewayUrl: String,
val lastSequence: Int,
)

View file

@ -11,7 +11,7 @@ fun <T : Event.Outgoing> T.pack(): Payload.Outgoing<T> {
val opcode: Int = when (this) { val opcode: Int = when (this) {
is Event.Heartbeat -> 1 is Event.Heartbeat -> 1
is Event.Identify -> 2 is Event.Identify -> 2
is Event.HeartbeatAck -> 11 is Event.Resume -> 6
is Event.QoSHeartbeat -> 40 is Event.QoSHeartbeat -> 40
} }
return Payload.Outgoing(op = opcode, d = this) return Payload.Outgoing(op = opcode, d = this)
@ -21,9 +21,11 @@ fun Payload.Unknown.asIncoming() : Payload.WithSequence {
return when (op) { return when (op) {
0 -> when (t) { 0 -> when (t) {
"READY" -> decode<Event.Ready>() "READY" -> decode<Event.Ready>()
"RESUMED" -> decode<Event.Resumed>()
else -> this else -> this
} }
1 -> decode<Event.Heartbeat>() 1 -> decode<Event.Heartbeat>()
9 -> decode<Event.InvalidSession>()
10 -> decode<Event.Hello>() 10 -> decode<Event.Hello>()
11 -> decode<Event.HeartbeatAck>() 11 -> decode<Event.HeartbeatAck>()
else -> this else -> this

View file

@ -2,6 +2,4 @@ package moe.lava.neon.core.api.gateway.handlers
import moe.lava.neon.core.api.gateway.Event import moe.lava.neon.core.api.gateway.Event
sealed interface Handler<T: Event.Incoming> { sealed interface Handler<T: Event.Incoming>
fun handle(event: T)
}

View file

@ -3,12 +3,18 @@ package moe.lava.neon.core.api.gateway.handlers
import co.touchlab.kermit.Logger import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject import dev.zacsweers.metro.Inject
import moe.lava.neon.core.api.gateway.Event import moe.lava.neon.core.api.gateway.Event
import moe.lava.neon.core.api.gateway.ResumeProperties
private val logger = Logger.withTag("neon.core.api.events/ready") private val logger = Logger.withTag("neon.core.api.events/ready")
@Inject @Inject
class ReadyHandler : Handler<Event.Ready> { class ReadyHandler : Handler<Event.Ready> {
override fun handle(event: Event.Ready) { fun handle(event: Event.Ready, updateResumeProps: (ResumeProperties) -> Unit) {
logger.i { "Received payload $event" } logger.i { "Received payload $event" }
updateResumeProps(ResumeProperties(
sessionId = event.sessionId,
resumeGatewayUrl = event.resumeGatewayUrl,
lastSequence = 0,
))
} }
} }

View file

@ -19,6 +19,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.AppScope import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesIntoMap import dev.zacsweers.metro.ContributesIntoMap
import dev.zacsweers.metro.Inject import dev.zacsweers.metro.Inject
@ -83,16 +84,25 @@ class SampleViewModel(
private val auth: AuthRepository, private val auth: AuthRepository,
private val gateway: GatewayHandler, private val gateway: GatewayHandler,
) : ViewModel() { ) : ViewModel() {
private val logger = Logger.withTag("neon.ui.screens/Sample")
val token get() = auth.token val token get() = auth.token
fun connect() { fun connect() {
viewModelScope.launch { viewModelScope.launch {
gateway.connect() try {
gateway.connect()
} catch(e: Throwable) {
logger.e(e) { "Failed to connect to gateway: ${e.stackTraceToString()}" }
}
} }
} }
fun disconnect() { fun disconnect() {
viewModelScope.launch { viewModelScope.launch {
gateway.disconnect() try {
gateway.disconnect()
} catch(e: Throwable) {
logger.e(e) { "Failed to connect to gateway: ${e.stackTraceToString()}" }
}
} }
} }
fun logout() { fun logout() {