|
|
|
@ -1,3 +1,4 @@ |
|
|
|
import log4js from "logger"; |
|
|
|
import { dbGlobal } from 'drizzle-pkg/lib/db'; |
|
|
|
import { oauthAccounts } from 'drizzle-pkg/lib/schema/auth'; |
|
|
|
import { eq, and } from 'drizzle-orm'; |
|
|
|
@ -6,6 +7,8 @@ import { OAuthError, OAuthErrorCodes } from './oauth-error'; |
|
|
|
import { registerUser, createSession } from '#server/service/auth'; |
|
|
|
import type { MinimalUser } from '#server/service/auth'; |
|
|
|
|
|
|
|
const logger = log4js.getLogger("OAUTH"); |
|
|
|
|
|
|
|
export type OAuthBinding = { |
|
|
|
id: number; |
|
|
|
provider: string; |
|
|
|
@ -77,6 +80,7 @@ export class OAuthManager { |
|
|
|
getAuthorizationUrl(providerName: string, userId?: number): string { |
|
|
|
const provider = providers[providerName]; |
|
|
|
if (!provider) { |
|
|
|
logger.warn(`[getAuthorizationUrl] Provider not found: ${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.PROVIDER_NOT_FOUND, |
|
|
|
`Provider ${providerName} not found` |
|
|
|
@ -89,6 +93,8 @@ export class OAuthManager { |
|
|
|
|
|
|
|
const { state } = this.stateStore.generate(providerName, redirectUri, userId); |
|
|
|
|
|
|
|
logger.info(`[getAuthorizationUrl] Generating auth URL for provider: ${providerName}, userId: ${userId ?? 'anonymous'}, redirectUri: ${redirectUri}`); |
|
|
|
|
|
|
|
const params = new URLSearchParams({ |
|
|
|
client_id: provider.clientId, |
|
|
|
redirect_uri: provider.redirectUri, |
|
|
|
@ -104,8 +110,11 @@ export class OAuthManager { |
|
|
|
code: string, |
|
|
|
state: string |
|
|
|
): Promise<OAuthCallbackResult> { |
|
|
|
logger.info(`[handleCallback] Provider: ${providerName}`); |
|
|
|
|
|
|
|
const provider = providers[providerName]; |
|
|
|
if (!provider) { |
|
|
|
logger.warn(`[handleCallback] Provider not found: ${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.PROVIDER_NOT_FOUND, |
|
|
|
`Provider ${providerName} not found` |
|
|
|
@ -114,6 +123,7 @@ export class OAuthManager { |
|
|
|
|
|
|
|
const oauthState = this.stateStore.consume(state); |
|
|
|
if (!oauthState) { |
|
|
|
logger.warn(`[handleCallback] Invalid or expired state`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.STATE_INVALID, |
|
|
|
'Invalid or expired state' |
|
|
|
@ -121,6 +131,7 @@ export class OAuthManager { |
|
|
|
} |
|
|
|
|
|
|
|
if (oauthState.providerName !== providerName) { |
|
|
|
logger.warn(`[handleCallback] State provider mismatch: expected ${oauthState.providerName}, got ${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.STATE_INVALID, |
|
|
|
'State provider mismatch' |
|
|
|
@ -129,6 +140,7 @@ export class OAuthManager { |
|
|
|
|
|
|
|
const tokenResponse = await this.exchangeToken(provider, code); |
|
|
|
if (!tokenResponse.access_token) { |
|
|
|
logger.error(`[handleCallback] Token exchange failed for provider: ${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.TOKEN_EXCHANGE_FAILED, |
|
|
|
'Failed to exchange token' |
|
|
|
@ -137,16 +149,20 @@ export class OAuthManager { |
|
|
|
|
|
|
|
const userInfo = await this.getUserInfo(provider, tokenResponse.access_token); |
|
|
|
if (!userInfo) { |
|
|
|
logger.error(`[handleCallback] Failed to get user info from provider: ${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.USER_INFO_FAILED, |
|
|
|
'Failed to get user info' |
|
|
|
); |
|
|
|
} |
|
|
|
|
|
|
|
logger.info(`[handleCallback] User info retrieved: provider=${providerName}, userId=${userInfo.providerUserId}, username=${userInfo.username}`); |
|
|
|
|
|
|
|
const existingAccount = await this.findOAuthAccount(providerName, userInfo.providerUserId); |
|
|
|
|
|
|
|
if (oauthState.isBinding) { |
|
|
|
if (existingAccount && existingAccount.userId !== oauthState.userId) { |
|
|
|
logger.warn(`[handleCallback] OAuth account already bound to another user: provider=${providerName}, providerUserId=${userInfo.providerUserId}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.ALREADY_BIND, |
|
|
|
'This OAuth account is already bound to another user' |
|
|
|
@ -162,6 +178,8 @@ export class OAuthManager { |
|
|
|
avatar: userInfo.avatar, |
|
|
|
}); |
|
|
|
|
|
|
|
logger.info(`[handleCallback] OAuth account bound successfully: userId=${oauthState.userId}, provider=${providerName}`); |
|
|
|
|
|
|
|
return { |
|
|
|
success: true, |
|
|
|
isNewUser: false, |
|
|
|
@ -171,6 +189,7 @@ export class OAuthManager { |
|
|
|
|
|
|
|
if (existingAccount) { |
|
|
|
const { sessionId, expiresAt } = await createSession(existingAccount.userId); |
|
|
|
logger.info(`[handleCallback] Existing user logged in: userId=${existingAccount.userId}, provider=${providerName}`); |
|
|
|
return { |
|
|
|
success: true, |
|
|
|
isNewUser: false, |
|
|
|
@ -196,6 +215,8 @@ export class OAuthManager { |
|
|
|
|
|
|
|
const { sessionId, expiresAt } = await createSession(newUser.id); |
|
|
|
|
|
|
|
logger.info(`[handleCallback] New user registered and logged in: userId=${newUser.id}, provider=${providerName}`); |
|
|
|
|
|
|
|
return { |
|
|
|
success: true, |
|
|
|
isNewUser: true, |
|
|
|
@ -209,11 +230,13 @@ export class OAuthManager { |
|
|
|
async unbindAccount(userId: number, providerName: string): Promise<void> { |
|
|
|
const deleted = await this.deleteOAuthAccount(userId, providerName); |
|
|
|
if (!deleted) { |
|
|
|
logger.warn(`[unbindAccount] OAuth account not found: userId=${userId}, provider=${providerName}`); |
|
|
|
throw new OAuthError( |
|
|
|
OAuthErrorCodes.BINDING_USER_MISMATCH, |
|
|
|
'OAuth account not found' |
|
|
|
); |
|
|
|
} |
|
|
|
logger.info(`[unbindAccount] OAuth account unbound: userId=${userId}, provider=${providerName}`); |
|
|
|
} |
|
|
|
|
|
|
|
async getUserBindings(userId: number): Promise<OAuthBinding[]> { |
|
|
|
@ -224,7 +247,7 @@ export class OAuthManager { |
|
|
|
username: account.username, |
|
|
|
email: account.email, |
|
|
|
avatar: account.avatar, |
|
|
|
boundAt: account.createdAt, |
|
|
|
boundAt: account.createdAt ?? new Date(), |
|
|
|
})); |
|
|
|
} |
|
|
|
|
|
|
|
@ -325,7 +348,8 @@ export class OAuthManager { |
|
|
|
eq(oauthAccounts.userId, userId), |
|
|
|
eq(oauthAccounts.provider, provider) |
|
|
|
)); |
|
|
|
return (result.rowCount ?? 0) > 0; |
|
|
|
const info = result as unknown as { changes?: number; rowsAffected?: number }; |
|
|
|
return (info.changes ?? info.rowsAffected ?? 0) > 0; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|