diff --git a/packages/api/src/auth/authProvider.js b/packages/api/src/auth/authProvider.js index 4f8852d73..7503feee4 100644 --- a/packages/api/src/auth/authProvider.js +++ b/packages/api/src/auth/authProvider.js @@ -9,6 +9,8 @@ const jwt = require('jsonwebtoken'); const logger = getLogger('authProvider'); class AuthProviderBase { + amoid = 'none'; + async login(login, password, options = undefined) { return {}; } @@ -51,9 +53,17 @@ class AuthProviderBase { getSingleConnectionId(req) { return null; } + + toJson() { + return { + amoid: this.amoid, + }; + } } class OAuthProvider extends AuthProviderBase { + amoid = 'oauth'; + shouldAuthorizeApi() { return true; } @@ -120,6 +130,8 @@ class OAuthProvider extends AuthProviderBase { } class ADProvider extends AuthProviderBase { + amoid = 'ad'; + async login(login, password) { const adConfig = { url: process.env.AD_URL, @@ -157,6 +169,8 @@ class ADProvider extends AuthProviderBase { } class LoginsProvider extends AuthProviderBase { + amoid = 'logins'; + async login(login, password) { if (password == process.env[`LOGIN_PASSWORD_${login}`]) { return { @@ -176,6 +190,8 @@ class LoginsProvider extends AuthProviderBase { } class DenyAllProvider extends AuthProviderBase { + amoid = 'deny'; + shouldAuthorizeApi() { return true; } @@ -233,19 +249,37 @@ function createEnvAuthProvider() { } } -let authProvider = createEnvAuthProvider(); +let defaultAuthProvider = createEnvAuthProvider(); +let authProviders = [defaultAuthProvider]; -function getAuthProvider() { - return authProvider; +function getAuthProviders() { + return authProviders; } -function setAuthProvider(value) { - authProvider = value; +function getAuthProviderById(amoid) { + return authProviders.find(x => x.amoid == amoid); +} + +function getDefaultAuthProvider() { + return defaultAuthProvider; +} + +function getAuthProviderFromReq(req) { + const authProviderId = req?.auth?.amoid || req?.user?.amoid; + return getAuthProviderById(authProviderId) ?? getDefaultAuthProvider(); +} + +function setAuthProviders(value, defaultProvider = null) { + authProviders = value; + defaultAuthProvider = defaultProvider || value[0]; } module.exports = { AuthProviderBase, detectEnvAuthProvider, - getAuthProvider, - setAuthProvider, + getAuthProviders, + getDefaultAuthProvider, + setAuthProviders, + getAuthProviderById, + getAuthProviderFromReq, }; diff --git a/packages/api/src/controllers/auth.js b/packages/api/src/controllers/auth.js index e52a32ffe..95cf7281e 100644 --- a/packages/api/src/controllers/auth.js +++ b/packages/api/src/controllers/auth.js @@ -5,7 +5,7 @@ const { getLogger } = require('dbgate-tools'); const AD = require('activedirectory2').promiseWrapper; const crypto = require('crypto'); const { getTokenSecret, getTokenLifetime } = require('../auth/authCommon'); -const { getAuthProvider } = require('../auth/authProvider'); +const { getAuthProviderFromReq, getAuthProviders, getDefaultAuthProvider, getAuthProviderById } = require('../auth/authProvider'); const storage = require('./storage'); const logger = getLogger('auth'); @@ -28,6 +28,7 @@ function authMiddleware(req, res, next) { '/auth/login', '/stream', 'storage/get-connections-for-login-page', + 'auth/get-providers', '/connections/dblogin', '/connections/dblogin-auth', '/connections/dblogin-auth-token', @@ -37,7 +38,7 @@ function authMiddleware(req, res, next) { const isAdminPage = req.headers['x-is-admin-page'] == 'true'; - if (!isAdminPage && !getAuthProvider().shouldAuthorizeApi()) { + if (!isAdminPage && !getAuthProviderFromReq(req).shouldAuthorizeApi()) { return next(); } let skipAuth = !!SKIP_AUTH_PATHS.find(x => req.path == getExpressPath(x)); @@ -68,11 +69,11 @@ function authMiddleware(req, res, next) { module.exports = { oauthToken_meta: true, async oauthToken(params) { - return getAuthProvider().oauthToken(params); + return getDefaultAuthProvider().oauthToken(params); }, login_meta: true, async login(params) { - const { login, password, isAdminPage } = params; + const { amoid, login, password, isAdminPage } = params; if (isAdminPage) { if (process.env.ADMIN_PASSWORD && process.env.ADMIN_PASSWORD == password) { @@ -94,7 +95,15 @@ module.exports = { return { error: 'Login failed' }; } - return getAuthProvider().login(login, password); + return getAuthProviderById(amoid).login(login, password); + }, + + getProviders_meta: true, + getProviders() { + return { + providers: getAuthProviders().map(x => x.toJson()), + default: getDefaultAuthProvider()?.amoid, + }; }, authMiddleware, diff --git a/packages/api/src/controllers/config.js b/packages/api/src/controllers/config.js index f551a5153..85c3cdc46 100644 --- a/packages/api/src/controllers/config.js +++ b/packages/api/src/controllers/config.js @@ -11,7 +11,7 @@ const AsyncLock = require('async-lock'); const currentVersion = require('../currentVersion'); const platformInfo = require('../utility/platformInfo'); const connections = require('../controllers/connections'); -const { getAuthProvider } = require('../auth/authProvider'); +const { getAuthProviderFromReq } = require('../auth/authProvider'); const lock = new AsyncLock(); @@ -28,7 +28,7 @@ module.exports = { get_meta: true, async get(_params, req) { - const authProvider = getAuthProvider(); + const authProvider = getAuthProviderFromReq(req); const login = authProvider.getCurrentLogin(req); const permissions = authProvider.getCurrentPermissions(req); const isLoginForm = authProvider.isLoginForm(); diff --git a/packages/api/src/main.js b/packages/api/src/main.js index 731d882ae..2cd80986f 100644 --- a/packages/api/src/main.js +++ b/packages/api/src/main.js @@ -34,7 +34,7 @@ const platformInfo = require('./utility/platformInfo'); const getExpressPath = require('./utility/getExpressPath'); const _ = require('lodash'); const { getLogger } = require('dbgate-tools'); -const { getAuthProvider } = require('./auth/authProvider'); +const { getDefaultAuthProvider } = require('./auth/authProvider'); const logger = getLogger('main'); @@ -48,7 +48,7 @@ function start() { if (process.env.BASIC_AUTH) { async function authorizer(username, password, cb) { try { - const resp = await getAuthProvider().login(username, password); + const resp = await getDefaultAuthProvider().login(username, password); if (resp.accessToken) { cb(null, true); } else { diff --git a/packages/api/src/utility/hasPermission.js b/packages/api/src/utility/hasPermission.js index 1d6b63abf..654ba8ceb 100644 --- a/packages/api/src/utility/hasPermission.js +++ b/packages/api/src/utility/hasPermission.js @@ -1,6 +1,6 @@ const { compilePermissions, testPermission } = require('dbgate-tools'); const _ = require('lodash'); -const { getAuthProvider } = require('../auth/authProvider'); +const { getAuthProviderFromReq } = require('../auth/authProvider'); const cachedPermissions = {}; @@ -10,7 +10,7 @@ function hasPermission(tested, req) { return true; } - const permissions = getAuthProvider().getCurrentPermissions(req); + const permissions = getAuthProviderFromReq(req).getCurrentPermissions(req); if (!cachedPermissions[permissions]) { cachedPermissions[permissions] = compilePermissions(permissions); diff --git a/packages/web/src/LoginPage.svelte b/packages/web/src/LoginPage.svelte index e76bd2419..a1b1f31dd 100644 --- a/packages/web/src/LoginPage.svelte +++ b/packages/web/src/LoginPage.svelte @@ -19,29 +19,47 @@ const config = useConfig(); let availableConnections = null; + let availableProviders = []; let isTesting = false; const testIdRef = createRef(0); let sqlConnectResult; - const values = writable({ databaseServer: null }); + let serversLoadedForAmoId = null; + + const values = writable({ amoid: null, databaseServer: null }); $: selectedConnection = availableConnections?.find(x => x.conid == $values.databaseServer); - async function loadAvailableServers() { - availableConnections = await apiCall('storage/get-connections-for-login-page'); - if (availableConnections?.length > 0) { - values.set({ databaseServer: availableConnections[0].conid }); + async function loadAvailableServers(amoid) { + if (amoid) { + availableConnections = await apiCall('storage/get-connections-for-login-page', { amoid }); + if (availableConnections?.length > 0) { + values.update(x => ({ ...x, databaseServer: availableConnections[0].conid })); + } + serversLoadedForAmoId = amoid; + } else { + availableConnections = null; } } + async function loadAvailableAuthProviders() { + const resp = await apiCall('auth/get-providers'); + availableProviders = resp.providers; + values.update(x => ({ ...x, amoid: resp.default })); + } + onMount(() => { const removed = document.getElementById('starting_dbgate_zero'); if (removed) removed.remove(); if (!isAdminPage) { - loadAvailableServers(); + loadAvailableAuthProviders(); } }); + + $: if ($values.amoid != serversLoadedForAmoId) { + loadAvailableServers($values.amoid); + }