import { dbGlobal } from "drizzle-pkg/lib/db";
import { sessions, users } from "drizzle-pkg/lib/schema/auth";
import { and, eq, gt, lte, sql } from "drizzle-orm";
import log4js from "logger";
import { randomUUID } from "crypto";
import { compare, hash } from "bcryptjs";

const logger = log4js.getLogger("AUTH");

/** Usernames: 3-20 characters, ASCII letters, digits, or underscore. */
const USERNAME_REGEX = /^[a-zA-Z0-9_]{3,20}$/;
/** Minimum password length, as measured by `String#length`. */
const MIN_PASSWORD_LENGTH = 6;
/** Session lifetime: 7 days, in milliseconds. */
const SESSION_EXPIRE_MS = 7 * 24 * 60 * 60 * 1000;
/** bcrypt cost factor used when hashing passwords. */
const BCRYPT_SALT_ROUNDS = 10;
/** How many times user creation is attempted when the generated id collides. */
const MAX_USER_INSERT_ATTEMPTS = 5;
/** PostgreSQL SQLSTATE for `unique_violation`. */
const PG_UNIQUE_VIOLATION = "23505";

type AuthPayload = {
  username: string;
  password: string;
};

type MinimalUser = {
  id: number;
  username: string;
};

/** Thrown when the submitted credentials are malformed. */
export class AuthValidationError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "AuthValidationError";
  }
}

/** Thrown when registration collides with an existing username. */
export class AuthConflictError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "AuthConflictError";
  }
}

/** Thrown when login fails (unknown username or wrong password). */
export class AuthFailedError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "AuthFailedError";
  }
}

/**
 * Runtime check that `payload` is a well-formed `{ username, password }` object.
 *
 * @throws AuthValidationError if `payload` is not a non-null object, if either
 *   field is not a string, if the username does not match `USERNAME_REGEX`, or
 *   if the password is shorter than `MIN_PASSWORD_LENGTH`.
 */
function validateCredentials(payload: unknown): asserts payload is AuthPayload {
  if (typeof payload !== "object" || payload === null) {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  // Read the fields as `unknown`: the payload may come straight from a request body.
  const { username, password } = payload as Partial<Record<keyof AuthPayload, unknown>>;
  if (typeof username !== "string" || typeof password !== "string") {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  if (!USERNAME_REGEX.test(username)) {
    throw new AuthValidationError("用户名格式不正确");
  }
  if (password.length < MIN_PASSWORD_LENGTH) {
    throw new AuthValidationError(`密码长度至少 ${MIN_PASSWORD_LENGTH} 位`);
  }
}

/**
 * Single, shared failure error for login.
 * The same message is used for unknown users and wrong passwords, so the caller cannot
 * tell which case occurred.
 */
function authFailedError(): AuthFailedError {
  return new AuthFailedError("用户名或密码错误");
}

/**
 * Returns the wrapped driver error when `err` is an `Error` carrying a
 * non-nullish `cause`; otherwise returns `err` unchanged.
 */
function unwrapDbError(err: unknown): unknown {
  if (!(err instanceof Error) || !("cause" in err)) {
    return err;
  }
  const cause = (err as { cause?: unknown }).cause;
  return cause ?? err;
}

/**
 * Inserts a new session row for `userId`.
 * The session has a random UUID id and expires `SESSION_EXPIRE_MS` from now.
 *
 * @returns the new session id and its expiry time.
 */
async function createSession(userId: number): Promise<{ sessionId: string; expiresAt: Date }> {
  const sessionId = randomUUID();
  const expiresAt = new Date(Date.now() + SESSION_EXPIRE_MS);
  await dbGlobal.insert(sessions).values({
    id: sessionId,
    userId,
    expiresAt,
  });
  return { sessionId, expiresAt };
}

/**
 * Computes a candidate id for a new user as `MAX(users.id) + 1` (1 for an empty table).
 *
 * This is not atomic: concurrent registrations can compute the same id, so the
 * caller must handle a primary-key conflict (insertUserWithRetry does).
 */
async function getNextUserId(): Promise<number> {
  const [row] = await dbGlobal
    .select({
      maxId: sql<number>`COALESCE(MAX(${users.id}), 0)`,
    })
    .from(users);
  // Aggregates may arrive as strings, e.g. node-postgres returns int8 as a string.
  // Coerce so that `+ 1` is numeric addition, not string concatenation.
  return Number(row?.maxId ?? 0) + 1;
}

/** True when `err` (or its `cause`) is a PostgreSQL unique-constraint violation. */
function isPgUniqueViolation(err: unknown): boolean {
  const dbError = unwrapDbError(err);
  if (!(dbError instanceof Error)) {
    return false;
  }
  return "code" in dbError && (dbError as { code?: string }).code === PG_UNIQUE_VIOLATION;
}

/** Lower-cased constraint name reported by the driver error, or "" if absent. */
function getPgConstraint(err: unknown): string {
  const dbError = unwrapDbError(err);
  if (!(dbError instanceof Error)) {
    return "";
  }
  if (!("constraint" in dbError)) {
    return "";
  }
  return ((dbError as { constraint?: string }).constraint ?? "").toLowerCase();
}

/**
 * Unique violation whose constraint name mentions "username".
 * Relies on the schema's naming convention for constraints.
 */
function isUsernameConflict(err: unknown): boolean {
  return isPgUniqueViolation(err) && getPgConstraint(err).includes("username");
}

/**
 * Unique violation whose constraint name looks like the primary key.
 * Call this only after `isUsernameConflict`: "id" is a loose substring match.
 */
function isUserIdConflict(err: unknown): boolean {
  if (!isPgUniqueViolation(err)) {
    return false;
  }
  const constraint = getPgConstraint(err);
  return constraint.includes("pkey") || constraint.includes("id");
}

/**
 * Inserts a user with an application-generated id.
 * On a primary-key collision (a concurrent registration took the same id), it retries
 * with a freshly computed id.
 *
 * @throws AuthConflictError if the username is already taken.
 * @throws Error ("创建用户失败,请稍后重试") if every attempt collided on the id,
 *   or if the insert unexpectedly returned no row.
 * Any other database error is rethrown unchanged.
 */
async function insertUserWithRetry(username: string, passwordHash: string): Promise<MinimalUser> {
  let lastIdConflict: unknown;
  for (let attempt = 0; attempt < MAX_USER_INSERT_ATTEMPTS; attempt++) {
    const userId = await getNextUserId();
    let rows: MinimalUser[];
    try {
      rows = await dbGlobal
        .insert(users)
        .values({
          id: userId,
          username,
          password: passwordHash,
        })
        .returning({
          id: users.id,
          username: users.username,
        });
    } catch (err) {
      if (isUsernameConflict(err)) {
        throw new AuthConflictError("用户名已存在");
      }
      if (isUserIdConflict(err)) {
        // Another insert grabbed this id first; recompute and try again.
        lastIdConflict = err;
        continue;
      }
      throw err;
    }
    const [newUser] = rows;
    if (!newUser) {
      throw new Error("创建用户失败,请稍后重试");
    }
    return newUser;
  }
  logger.warn("user id allocation failed after %d attempts", MAX_USER_INSERT_ATTEMPTS, lastIdConflict);
  throw new Error("创建用户失败,请稍后重试");
}

/**
 * Registers a new user after validating the credentials and hashing the password.
 *
 * @returns the created user's id and username.
 * @throws AuthValidationError on malformed credentials.
 * @throws AuthConflictError if the username is already taken.
 */
export async function registerUser(payload: AuthPayload): Promise<MinimalUser> {
  validateCredentials(payload);
  const { username, password } = payload;
  const passwordHash = await hash(password, BCRYPT_SALT_ROUNDS);
  const newUser = await insertUserWithRetry(username, passwordHash);
  logger.info("user registered: %s", username);
  return newUser;
}

/** Lazily created bcrypt hash used to equalize login timing for unknown usernames. */
let dummyPasswordHash: Promise<string> | undefined;

/**
 * Returns (creating once) a bcrypt hash of a random value.
 * `loginUser` compares against it when the username does not exist. Response time
 * then does not reveal whether a username is registered.
 */
function getDummyPasswordHash(): Promise<string> {
  if (dummyPasswordHash === undefined) {
    dummyPasswordHash = hash(randomUUID(), BCRYPT_SALT_ROUNDS);
  }
  return dummyPasswordHash;
}

/**
 * Verifies credentials and, on success, creates a new session.
 *
 * @returns the user, the new session id, and the session expiry.
 * @throws AuthValidationError on malformed credentials.
 * @throws AuthFailedError if the username is unknown or the password is wrong.
 */
export async function loginUser(
  payload: AuthPayload,
): Promise<{ user: MinimalUser; sessionId: string; expiresAt: Date }> {
  validateCredentials(payload);
  const { username, password } = payload;
  const [user] = await dbGlobal
    .select({
      id: users.id,
      username: users.username,
      password: users.password,
    })
    .from(users)
    .where(eq(users.username, username))
    .limit(1);
  if (!user) {
    // Spend roughly the same bcrypt time as a real check, so unknown
    // usernames are not distinguishable from wrong passwords by latency.
    await compare(password, await getDummyPasswordHash());
    throw authFailedError();
  }
  const isMatch = await compare(password, user.password);
  if (!isMatch) {
    throw authFailedError();
  }
  const { sessionId, expiresAt } = await createSession(user.id);
  logger.info("user login: %s", username);
  return {
    user: {
      id: user.id,
      username: user.username,
    } satisfies MinimalUser,
    sessionId,
    expiresAt,
  };
}

/**
 * Deletes the session with the given id.
 * A missing session is not an error; the function always resolves to `true`.
 */
export async function logoutUser(sessionId: string): Promise<boolean> {
  await dbGlobal.delete(sessions).where(eq(sessions.id, sessionId));
  logger.info("session logout");
  return true;
}

/**
 * Resolves the user owning a non-expired session.
 * If the session exists but has expired, it is deleted.
 *
 * @returns the session's user, or `null` when the id is empty, unknown, or expired.
 */
export async function getCurrentUser(sessionId: string): Promise<MinimalUser | null> {
  if (!sessionId) {
    return null;
  }
  const now = new Date();
  const [row] = await dbGlobal
    .select({
      userId: users.id,
      username: users.username,
    })
    .from(sessions)
    .innerJoin(users, eq(sessions.userId, users.id))
    .where(and(eq(sessions.id, sessionId), gt(sessions.expiresAt, now)));
  if (!row) {
    // Purge the session only if it is expired.
    // Use the same `now` as the lookup, so the app clock and DB clock cannot disagree.
    await dbGlobal
      .delete(sessions)
      .where(and(eq(sessions.id, sessionId), lte(sessions.expiresAt, now)));
    return null;
  }
  return {
    id: row.userId,
    username: row.username,
  };
}