diff --git a/app/lib/access_token_extension.rb b/app/lib/access_token_extension.rb new file mode 100644 index 000000000..3e184e775 --- /dev/null +++ b/app/lib/access_token_extension.rb @@ -0,0 +1,17 @@ +# frozen_string_literal: true + +module AccessTokenExtension + extend ActiveSupport::Concern + + included do + after_commit :push_to_streaming_api + end + + def revoke(clock = Time) + update(revoked_at: clock.now.utc) + end + + def push_to_streaming_api + Redis.current.publish("timeline:access_token:#{id}", Oj.dump(event: :kill)) if revoked? || destroyed? + end +end diff --git a/app/models/session_activation.rb b/app/models/session_activation.rb index 34d25c83d..b0ce9d112 100644 --- a/app/models/session_activation.rb +++ b/app/models/session_activation.rb @@ -70,12 +70,16 @@ class SessionActivation < ApplicationRecord end def assign_access_token - superapp = Doorkeeper::Application.find_by(superapp: true) + self.access_token = Doorkeeper::AccessToken.create!(access_token_attributes) + end - self.access_token = Doorkeeper::AccessToken.create!(application_id: superapp&.id, - resource_owner_id: user_id, - scopes: 'read write follow', - expires_in: Doorkeeper.configuration.access_token_expires_in, - use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?) + def access_token_attributes + { + application_id: Doorkeeper::Application.find_by(superapp: true)&.id, + resource_owner_id: user_id, + scopes: 'read write follow', + expires_in: Doorkeeper.configuration.access_token_expires_in, + use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?, + } end end diff --git a/config/application.rb b/config/application.rb index de2951487..af7735221 100644 --- a/config/application.rb +++ b/config/application.rb @@ -140,6 +140,7 @@ module Mastodon Doorkeeper::AuthorizationsController.layout 'modal' Doorkeeper::AuthorizedApplicationsController.layout 'admin' Doorkeeper::Application.send :include, ApplicationExtension + Doorkeeper::AccessToken.send :include, AccessTokenExtension Devise::FailureApp.send :include, AbstractController::Callbacks Devise::FailureApp.send :include, HttpAcceptLanguage::EasyAccess Devise::FailureApp.send :include, Localized diff --git a/streaming/index.js b/streaming/index.js index 791a26941..877f45d19 100644 --- a/streaming/index.js +++ b/streaming/index.js @@ -294,7 +294,7 @@ const startWorker = (workerId) => { return; } - client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { + client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => { done(); if (err) { @@ -310,6 +310,7 @@ const startWorker = (workerId) => { return; } + req.accessTokenId = result.rows[0].id; req.scopes = result.rows[0].scopes.split(' '); req.accountId = result.rows[0].account_id; req.chosenLanguages = result.rows[0].chosen_languages; @@ -450,6 +451,55 @@ const startWorker = (workerId) => { }); }; + /** + * @typedef SystemMessageHandlers + * @property {function(): void} onKill + */ + + /** + * @param {any} req + * @param {SystemMessageHandlers} eventHandlers + * @return {function(string): void} + */ + const createSystemMessageListener = (req, eventHandlers) => { + return message => { + const json = parseJSON(message); + + if (!json) return; + + const { event } = json; + + log.silly(req.requestId, `System message for ${req.accountId}: ${event}`); + + if (event === 'kill') { + log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`); + eventHandlers.onKill(); + } + } + }; + + /** + * @param {any} req + * @param {any} res + */ + const subscribeHttpToSystemChannel = (req, res) => { + const systemChannelId = `timeline:access_token:${req.accessTokenId}`; + + const listener = createSystemMessageListener(req, { + + onKill () { + res.end(); + }, + + }); + + res.on('close', () => { + unsubscribe(`${redisPrefix}${systemChannelId}`, listener); + }); + + subscribe(`${redisPrefix}${systemChannelId}`, listener); + }; + /** * @param {any} req * @param {any} res @@ -462,6 +512,8 @@ const startWorker = (workerId) => { } accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => { + subscribeHttpToSystemChannel(req, res); + }).then(() => { next(); }).catch(err => { next(err); @@ -536,7 +588,9 @@ const startWorker = (workerId) => { const listener = message => { const json = parseJSON(message); + if (!json) return; + const { event, payload, queued_at } = json; const transmit = () => { @@ -902,6 +956,28 @@ const startWorker = (workerId) => { socket.send(JSON.stringify({ error: err.toString() })); }); + /** + * @param {WebSocketSession} session + */ + const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => { + const systemChannelId = `timeline:access_token:${request.accessTokenId}`; + + const listener = createSystemMessageListener(request, { + + onKill () { + socket.close(); + }, + + }); + + subscribe(`${redisPrefix}${systemChannelId}`, listener); + + subscriptions[systemChannelId] = { + listener, + stopHeartbeat: () => {}, + }; + }; + /** * @param {string|string[]} arrayOrString * @return {string} @@ -948,7 +1024,9 @@ const startWorker = (workerId) => { ws.on('message', data => { const json = parseJSON(data); + if (!json) return; + const { type, stream, ...params } = json; if (type === 'subscribe') { @@ -960,6 +1038,8 @@ const startWorker = (workerId) => { } }); + subscribeWebsocketToSystemChannel(session); + if (location.query.stream) { subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query); }