feat(gateway): handle resume and opcode 9

This commit is contained in:
Cilly Leang 2026-01-27 21:00:15 +11:00
parent bed582b953
commit 22faef0fb0
Signed by: cilly
GPG key ID: 6500251E087653C9
9 changed files with 187 additions and 62 deletions

View file

@ -0,0 +1,19 @@
package moe.lava.neon.core.api.gateway
import io.ktor.websocket.CloseReason
sealed interface GatewayCloseReason {
sealed interface ClientInitiated : GatewayCloseReason
sealed class ShouldReconnect(val resume: Boolean) : GatewayCloseReason
sealed class KeepDisconnected : GatewayCloseReason
data object MissedHeartbeat : ShouldReconnect(resume = true), ClientInitiated
data class SkippedSequence(val next: Int) : ShouldReconnect(resume = true), ClientInitiated
data class InvalidSession(val resumable: Boolean) : ShouldReconnect(resume = resumable), ClientInitiated
// TODO: handle non-resumable cases properly
data class ServerClosed(val closeCode: CloseReason) : ShouldReconnect(resume = true)
data object ClientPaused : KeepDisconnected(), ClientInitiated
data object Unknown : ShouldReconnect(resume = true)
}

View file

@ -2,9 +2,15 @@ package moe.lava.neon.core.api.gateway
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.serialization.ExperimentalSerializationApi
import moe.lava.neon.core.di.EventHandlerGraph
import moe.lava.neon.core.repository.AuthRepository
import kotlin.math.pow
import kotlin.time.Duration.Companion.seconds
@Inject
class GatewayHandler(
@ -12,7 +18,11 @@ class GatewayHandler(
private val handlers: EventHandlerGraph,
) {
private val logger = Logger.withTag("neon.core.api.gateway/handler")
private val scope = CoroutineScope(Dispatchers.IO)
private var session: GatewaySession? = null
private var resumeProps: ResumeProperties? = null
private var retryAttempts: Int = 0
@OptIn(ExperimentalSerializationApi::class)
suspend fun connect() {
@ -26,13 +36,45 @@ class GatewayHandler(
session = GatewaySession.start(
token = token,
eventHandlers = handlers,
onDestroy = { session = null }
resumeProps = resumeProps,
onSuccess = {
logger.d { "Successful session start" }
retryAttempts = 0
},
onDestroy = { reason, resumeProps ->
session = null
if (reason is GatewayCloseReason.KeepDisconnected) {
this.resumeProps = resumeProps
}
if (reason is GatewayCloseReason.ShouldReconnect) {
if (reason.resume) {
this.resumeProps = resumeProps
} else {
this.resumeProps = null
}
scope.launch {
var res: Result<Unit>
do {
val dur = 2.0.pow(retryAttempts).seconds
logger.d { "Reconnecting in ${dur.inWholeMilliseconds}ms" }
delay(dur)
retryAttempts += 1
res = runCatching { connect() }
res.exceptionOrNull()?.let {
logger.e(it) { "Reconnect failed" }
}
} while(res.isFailure)
}
}
}
)
}
suspend fun disconnect() {
val session = session
?: throw IllegalStateException("Tried disconnecting with no session")
session.close()
session.close(GatewayCloseReason.ClientPaused)
}
}

View file

@ -9,11 +9,11 @@ import io.ktor.client.plugins.websocket.webSocketSession
import io.ktor.client.request.parameter
import io.ktor.http.userAgent
import io.ktor.websocket.Frame
import io.ktor.websocket.close
import io.ktor.websocket.readText
import io.ktor.websocket.send
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.consumeAsFlow
@ -35,10 +35,13 @@ class GatewaySession private constructor(
private val token: String,
private val handlers: EventHandlerGraph,
private val scope: CoroutineScope,
private val onDestroy: () -> Unit,
private var resumeProps: ResumeProperties?,
private val onDestroy: (GatewayCloseReason, ResumeProperties?) -> Unit,
private val onSuccess: () -> Unit,
) {
private var lastSeq: Int? = null
private var lastSeq: Int? = resumeProps?.lastSequence
private var missedHeartbeats = 0
private var closeReason: GatewayCloseReason? = null
companion object {
suspend fun start(
@ -49,9 +52,13 @@ class GatewaySession private constructor(
install(WebSockets)
},
scope: CoroutineScope = CoroutineScope(Dispatchers.IO),
onDestroy: () -> Unit,
resumeProps: ResumeProperties? = null,
onDestroy: (GatewayCloseReason, ResumeProperties?) -> Unit,
onSuccess: () -> Unit,
): GatewaySession {
val ws = client.webSocketSession("wss://gateway.discord.gg/") {
val ws = client.webSocketSession(
resumeProps?.resumeGatewayUrl ?: "wss://gateway.discord.gg/"
) {
userAgent(ApiConstants.gatewayUserAgent)
url {
parameter("encoding", "json")
@ -60,13 +67,13 @@ class GatewaySession private constructor(
}
}
return GatewaySession(ws, token, eventHandlers, scope, onDestroy)
return GatewaySession(ws, token, eventHandlers, scope, resumeProps, onDestroy, onSuccess)
}
}
init {
ws.incoming.consumeAsFlow()
.onCompletion { close(it) }
.onCompletion { onClose(it) }
.onEach { frame ->
if (frame !is Frame.Text)
// if (frame !is Frame.Text && frame !is Frame.Binary)
@ -79,7 +86,7 @@ class GatewaySession private constructor(
if (seq + 1 == raw.s) {
this.lastSeq = raw.s
} else if (raw.s != null) {
resume(ResumeReason.SkippedSequence(raw.s))
close(GatewayCloseReason.SkippedSequence(raw.s))
return@onEach
}
@ -95,8 +102,13 @@ class GatewaySession private constructor(
logger.d { payload.toString() }
when (val event = payload.d) {
is Event.Hello -> handleHello(event)
is Event.Ready -> handlers.ready.handle(event)
is Event.Ready -> handlers.ready.handle(event) {
resumeProps = it
onSuccess()
}
is Event.Resumed -> onSuccess()
is Event.Heartbeat -> handleHeartbeat()
is Event.InvalidSession -> close(GatewayCloseReason.InvalidSession(event.resumable))
is Event.HeartbeatAck -> { missedHeartbeats -= 1 }
}
}
@ -112,14 +124,23 @@ class GatewaySession private constructor(
}
private suspend fun handleHello(e: Event.Hello) {
val resumeProps = resumeProps
if (resumeProps != null) {
Event.Resume(
token = token,
sessionId = resumeProps.sessionId,
seq = resumeProps.lastSequence
).pack().send()
} else {
Event.Identify(token = token).pack().send()
}
val interval = e.heartbeatInterval.milliseconds
scope.launch {
delay(interval * Random.nextDouble())
while (true) {
if (missedHeartbeats >= 1) {
resume(ResumeReason.MissedHeartbeat)
close(GatewayCloseReason.MissedHeartbeat)
break
}
Event.QoSHeartbeat(lastSeq).pack().send()
@ -129,34 +150,40 @@ class GatewaySession private constructor(
}
}
private sealed class ResumeReason {
data object MissedHeartbeat : ResumeReason()
data class SkippedSequence(val next: Int) : ResumeReason()
data class CloseCode(val code: Int) : ResumeReason()
}
private suspend fun resume(reason: ResumeReason?) {
fun close(reason: GatewayCloseReason.ClientInitiated?) {
val msg = when (reason) {
is ResumeReason.MissedHeartbeat ->
is GatewayCloseReason.MissedHeartbeat ->
"heartbeat missed"
is ResumeReason.SkippedSequence ->
is GatewayCloseReason.SkippedSequence ->
"payloads skipped one sequence (expected: $lastSeq, actual: ${reason.next})"
is ResumeReason.CloseCode ->
"closed with code ${reason.code}"
is GatewayCloseReason.InvalidSession ->
"invalid session (resumable: $reason)"
is GatewayCloseReason.ClientPaused ->
"client requested pause"
null ->
"no reason"
}
closeReason = reason
logger.e { "Resuming, cause: $msg" }
// TODO
logger.e { "Client-initiated close, cause: $msg" }
ws.cancel()
}
// TODO: handle resuming, etc..
suspend fun close(error: Throwable? = null) {
@OptIn(ExperimentalCoroutinesApi::class)
private fun onClose(error: Throwable? = null) {
logger.d(error) { "Websocket connection closed, cleaning up..." }
ws.close()
if (scope.isActive) scope.cancel()
onDestroy()
if (resumeProps == null) {
logger.w { "No resume props stored" }
}
onDestroy(
closeReason
?: runCatching { ws.closeReason.getCompleted() }
.getOrNull()
?.let { GatewayCloseReason.ServerClosed(it) }
?: GatewayCloseReason.Unknown,
resumeProps?.copy(lastSequence = lastSeq ?: 0)
)
}
private suspend inline fun <reified T : Event.Outgoing> Payload.Outgoing<T>.send() {

View file

@ -1,6 +1,5 @@
package moe.lava.neon.core.api.gateway
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import moe.lava.neon.core.api.ApiConstants
@ -49,32 +48,8 @@ sealed interface Event {
@Serializable
value class Heartbeat(val lastSequence: Int?) : Incoming, Outgoing
// 40
@Serializable
data class QoSHeartbeat(
val seq: Int?,
val qos: QoSPayload = QoSPayload(),
) : Outgoing {
@Serializable
data class QoSPayload(
val ver: Int = 27,
val active: Boolean = true,
val reasons: List<String> = listOf("foregrounded"),
)
}
// 11
@JvmInline
@Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 10
@Serializable
data class Hello(val heartbeatInterval: Int) : Incoming
// 2
@Serializable
@OptIn(ExperimentalSerializationApi::class)
data class Identify(
val token: String,
val properties: ApiConstants.GatewayProperties = ApiConstants.GatewayProperties(),
@ -89,6 +64,42 @@ sealed interface Event {
// val clientState: ClientState,
) : Outgoing
// 6
@Serializable
data class Resume(
val token: String,
val sessionId: String,
val seq: Int,
) : Outgoing
// 9
@JvmInline
@Serializable
value class InvalidSession(val resumable: Boolean) : Incoming
// 10
@Serializable
data class Hello(val heartbeatInterval: Int) : Incoming
// 11
@JvmInline
@Serializable
value class HeartbeatAck(private val nothing: Nothing?) : Incoming
// 40
@Serializable
data class QoSHeartbeat(
val seq: Int?,
val qos: QoSPayload = QoSPayload(),
) : Outgoing {
@Serializable
data class QoSPayload(
val ver: Int = 27,
val active: Boolean = true,
val reasons: List<String> = listOf("foregrounded"),
)
}
@Serializable
data class Ready(
val v: Int,
@ -98,4 +109,7 @@ sealed interface Event {
val resumeGatewayUrl: String,
// val application: Application,
) : Dispatch()
@Serializable
data object Resumed : Dispatch()
}

View file

@ -0,0 +1,7 @@
package moe.lava.neon.core.api.gateway
data class ResumeProperties(
val sessionId: String,
val resumeGatewayUrl: String,
val lastSequence: Int,
)

View file

@ -11,7 +11,7 @@ fun <T : Event.Outgoing> T.pack(): Payload.Outgoing<T> {
val opcode: Int = when (this) {
is Event.Heartbeat -> 1
is Event.Identify -> 2
is Event.HeartbeatAck -> 11
is Event.Resume -> 6
is Event.QoSHeartbeat -> 40
}
return Payload.Outgoing(op = opcode, d = this)
@ -21,9 +21,11 @@ fun Payload.Unknown.asIncoming() : Payload.WithSequence {
return when (op) {
0 -> when (t) {
"READY" -> decode<Event.Ready>()
"RESUMED" -> decode<Event.Resumed>()
else -> this
}
1 -> decode<Event.Heartbeat>()
9 -> decode<Event.InvalidSession>()
10 -> decode<Event.Hello>()
11 -> decode<Event.HeartbeatAck>()
else -> this

View file

@ -2,6 +2,4 @@ package moe.lava.neon.core.api.gateway.handlers
import moe.lava.neon.core.api.gateway.Event
sealed interface Handler<T: Event.Incoming> {
fun handle(event: T)
}
sealed interface Handler<T: Event.Incoming>

View file

@ -3,12 +3,18 @@ package moe.lava.neon.core.api.gateway.handlers
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.Inject
import moe.lava.neon.core.api.gateway.Event
import moe.lava.neon.core.api.gateway.ResumeProperties
private val logger = Logger.withTag("neon.core.api.events/ready")
@Inject
class ReadyHandler : Handler<Event.Ready> {
override fun handle(event: Event.Ready) {
fun handle(event: Event.Ready, updateResumeProps: (ResumeProperties) -> Unit) {
logger.i { "Received payload $event" }
updateResumeProps(ResumeProperties(
sessionId = event.sessionId,
resumeGatewayUrl = event.resumeGatewayUrl,
lastSequence = 0,
))
}
}

View file

@ -19,6 +19,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import co.touchlab.kermit.Logger
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesIntoMap
import dev.zacsweers.metro.Inject
@ -83,16 +84,25 @@ class SampleViewModel(
private val auth: AuthRepository,
private val gateway: GatewayHandler,
) : ViewModel() {
private val logger = Logger.withTag("neon.ui.screens/Sample")
val token get() = auth.token
fun connect() {
viewModelScope.launch {
try {
gateway.connect()
} catch(e: Throwable) {
logger.e(e) { "Failed to connect to gateway: ${e.stackTraceToString()}" }
}
}
}
fun disconnect() {
viewModelScope.launch {
try {
gateway.disconnect()
} catch(e: Throwable) {
logger.e(e) { "Failed to connect to gateway: ${e.stackTraceToString()}" }
}
}
}
fun logout() {