You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

235 lines
6.4 KiB

import { dbGlobal } from "drizzle-pkg/lib/db";
import { sessions, users } from "drizzle-pkg/lib/schema/auth";
import { and, eq, gt, lte, sql } from "drizzle-orm";
import log4js from "logger";
import { randomUUID } from "crypto";
import { compare, hash } from "bcryptjs";
/** Logger used for authentication events (registration, login, logout). */
const logger = log4js.getLogger("AUTH");
/** Allowed usernames: 3-20 characters drawn from ASCII letters, digits and underscore. */
const USERNAME_REGEX = /^[a-zA-Z0-9_]{3,20}$/;
/** Minimum accepted password length, measured with `string.length` (UTF-16 code units). */
const MIN_PASSWORD_LENGTH = 6;
/** Session lifetime: 7 days expressed in milliseconds. */
const SESSION_EXPIRE_MS = 7 * 24 * 60 * 60 * 1000;
/** Credentials submitted for registration or login. */
type AuthPayload = {
  username: string;
  password: string;
};
/** Caller-facing view of a user: id and username only, never the password hash. */
type MinimalUser = {
  id: number;
  username: string;
};
/** Thrown when submitted credentials are malformed (wrong types, bad username, short password). */
export class AuthValidationError extends Error {
  override name = "AuthValidationError";

  constructor(message: string) {
    super(message);
  }
}

/** Thrown when registration collides with an existing username. */
export class AuthConflictError extends Error {
  override name = "AuthConflictError";

  constructor(message: string) {
    super(message);
  }
}

/** Thrown when a login attempt fails (unknown user or wrong password). */
export class AuthFailedError extends Error {
  override name = "AuthFailedError";

  constructor(message: string) {
    super(message);
  }
}
/**
 * Asserts that `payload` is an object carrying a well-formed username and password.
 *
 * @param payload - Untrusted request body.
 * @throws AuthValidationError if the payload is not a non-null object, if either field
 *   is not a string, if the username does not match `USERNAME_REGEX`, or if the
 *   password is shorter than `MIN_PASSWORD_LENGTH`.
 */
function validateCredentials(payload: unknown): asserts payload is AuthPayload {
  if (typeof payload !== "object" || payload === null) {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  const { username, password } = payload as Partial<AuthPayload>;
  if (typeof username !== "string" || typeof password !== "string") {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  if (!USERNAME_REGEX.test(username)) {
    throw new AuthValidationError("用户名格式不正确");
  }
  // NOTE(review): there is no upper bound on password length; bcrypt only uses the
  // first 72 bytes of its input, so longer passwords are silently truncated.
  if (password.length < MIN_PASSWORD_LENGTH) {
    // Build the message from the constant so the stated limit always matches the check.
    throw new AuthValidationError(`密码长度至少 ${MIN_PASSWORD_LENGTH} 位`);
  }
}
/**
 * Builds the single, deliberately generic error used for every failed login, so the
 * response does not say whether the username or the password was wrong.
 */
function authFailedError(): AuthFailedError {
  const genericMessage = "用户名或密码错误";
  return new AuthFailedError(genericMessage);
}
/**
 * Peels one wrapper layer off a database error.
 *
 * If `err` is an `Error` with a non-nullish `cause`, that cause is returned (this is
 * where the driver-level error with `code`/`constraint` is expected to live).
 * Otherwise `err` itself is returned unchanged.
 */
function unwrapDbError(err: unknown): unknown {
  if (err instanceof Error && "cause" in err) {
    const { cause } = err as { cause?: unknown };
    return cause ?? err;
  }
  return err;
}
/**
 * Persists a new session for `userId`.
 *
 * The session id is a random UUID and the session expires `SESSION_EXPIRE_MS`
 * after the current time.
 *
 * @returns The new session id and its expiry timestamp.
 */
async function createSession(userId: number) {
  const session = {
    id: randomUUID(),
    userId,
    expiresAt: new Date(Date.now() + SESSION_EXPIRE_MS),
  };
  await dbGlobal.insert(sessions).values(session);
  return { sessionId: session.id, expiresAt: session.expiresAt };
}
/**
 * Proposes the next user id: one more than the current maximum `users.id`, or 1
 * when the table is empty.
 *
 * This is not atomic. Concurrent callers can receive the same value, and the
 * resulting unique-key conflict on insert is expected to be handled by the caller.
 */
async function getNextUserId(): Promise<number> {
  const [row] = await dbGlobal
    .select({
      maxId: sql<number>`COALESCE(MAX(${users.id}), 0)`,
    })
    .from(users);
  // `sql<number>` only labels the type; it does not convert the value. Drivers can
  // return aggregates as strings (e.g. int8/numeric in node-postgres), which would
  // turn `+ 1` into string concatenation ("41" + 1 === "411"). Coerce explicitly.
  const currentMax = Number(row?.maxId ?? 0);
  return currentMax + 1;
}
/**
 * Reports whether `err` (or the error it wraps) carries the Postgres SQLSTATE
 * `23505`, i.e. a unique-constraint violation.
 */
function isPgUniqueViolation(err: unknown) {
  const candidate = unwrapDbError(err);
  const sqlState =
    candidate instanceof Error && "code" in candidate
      ? (candidate as { code?: string }).code
      : undefined;
  return sqlState === "23505";
}
/**
 * Returns the lower-cased constraint name reported on `err` (or the error it wraps).
 * Returns an empty string when the value is not an `Error` or has no `constraint`.
 */
function getPgConstraint(err: unknown) {
  const candidate = unwrapDbError(err);
  if (candidate instanceof Error && "constraint" in candidate) {
    const { constraint } = candidate as { constraint?: string };
    return (constraint ?? "").toLowerCase();
  }
  return "";
}
/** True when `err` is a unique violation on a constraint whose name mentions "username". */
function isUsernameConflict(err: unknown) {
  if (!isPgUniqueViolation(err)) {
    return false;
  }
  return getPgConstraint(err).includes("username");
}
/**
 * True when `err` is a unique violation whose constraint name contains "pkey" or
 * "id". This is a substring heuristic, so it can also match unrelated constraints
 * that happen to contain "id".
 */
function isUserIdConflict(err: unknown) {
  if (isPgUniqueViolation(err)) {
    const constraint = getPgConstraint(err);
    return ["pkey", "id"].some((fragment) => constraint.includes(fragment));
  }
  return false;
}
/**
 * Inserts a user with an application-assigned id, retrying when another insert
 * claimed the same id first.
 *
 * Ids come from `getNextUserId` (MAX(id) + 1), which can race with concurrent
 * registrations. On an id conflict, a fresh id is computed and the insert is
 * retried, up to `maxRetry` attempts in total.
 *
 * @returns The inserted user's id and username.
 * @throws AuthConflictError if the username is already taken.
 * @throws Error("创建用户失败,请稍后重试") if every attempt hit an id conflict, or
 *   if the insert returned no row.
 * @throws Any other database error, rethrown unchanged.
 */
async function insertUserWithRetry(username: string, passwordHash: string): Promise<MinimalUser> {
  const maxRetry = 5;
  for (let attempt = 0; attempt < maxRetry; attempt++) {
    const userId = await getNextUserId();
    try {
      const [newUser] = await dbGlobal
        .insert(users)
        .values({
          id: userId,
          username,
          password: passwordHash,
        })
        .returning({
          id: users.id,
          username: users.username,
        });
      if (!newUser) {
        // RETURNING produced no row; never hand `undefined` back as a user.
        throw new Error("创建用户失败,请稍后重试");
      }
      return newUser as MinimalUser;
    } catch (err) {
      if (isUsernameConflict(err)) {
        throw new AuthConflictError("用户名已存在");
      }
      if (isUserIdConflict(err)) {
        // Lost the MAX(id)+1 race: recompute the id and try again. Once attempts
        // run out, the loop exits and the explicit failure below is thrown.
        continue;
      }
      throw err;
    }
  }
  throw new Error("创建用户失败,请稍后重试");
}
/**
 * Registers a new account.
 *
 * Validates the credentials, hashes the password with bcrypt (cost 10), inserts the
 * user and logs the registration.
 *
 * @returns The created user's id and username.
 * @throws AuthValidationError for malformed credentials.
 * @throws AuthConflictError if the username is taken.
 */
export async function registerUser(payload: AuthPayload): Promise<MinimalUser> {
  validateCredentials(payload);
  const hashedPassword = await hash(payload.password, 10);
  const created = await insertUserWithRetry(payload.username, hashedPassword);
  logger.info("user registered: %s", payload.username);
  return created;
}
/**
 * Lazily created bcrypt hash of a random secret. It is compared against when the
 * username does not exist, so that path also pays the bcrypt cost.
 */
let dummyPasswordHash: Promise<string> | undefined;

/** Returns the cached dummy hash, creating it on first use with the same cost (10) as registerUser. */
function getDummyPasswordHash(): Promise<string> {
  if (!dummyPasswordHash) {
    dummyPasswordHash = hash(randomUUID(), 10);
  }
  return dummyPasswordHash;
}

/**
 * Authenticates a username/password pair and opens a new session.
 *
 * @returns The user's id and username, plus the new session id and its expiry.
 * @throws AuthValidationError for malformed credentials.
 * @throws AuthFailedError if the user does not exist or the password does not match.
 *   Both cases use the same message.
 */
export async function loginUser(payload: AuthPayload) {
  validateCredentials(payload);
  const { username, password } = payload;
  const [user] = await dbGlobal
    .select({
      id: users.id,
      username: users.username,
      password: users.password,
    })
    .from(users)
    .where(eq(users.username, username));
  // Always run a bcrypt comparison, even for unknown usernames. Returning early
  // would make those requests measurably faster and reveal which usernames exist.
  const storedHash = user ? user.password : await getDummyPasswordHash();
  const isMatch = await compare(password, storedHash);
  if (!user || !isMatch) {
    throw authFailedError();
  }
  const { sessionId, expiresAt } = await createSession(user.id);
  logger.info("user login: %s", username);
  return {
    user: {
      id: user.id,
      username: user.username,
    } satisfies MinimalUser,
    sessionId,
    expiresAt,
  };
}
/**
 * Deletes the session row with the given id and logs the logout.
 *
 * @returns Always `true`, whether or not a matching session existed.
 */
export async function logoutUser(sessionId: string) {
  const matchesSession = eq(sessions.id, sessionId);
  await dbGlobal.delete(sessions).where(matchesSession);
  logger.info("session logout");
  return true;
}
/**
 * Resolves a session id to its user.
 *
 * A session is valid when it exists, its `expiresAt` is after the current time, and
 * its user row exists. If no valid session is found, the session row is deleted
 * when it has expired.
 *
 * @returns The session's user, or `null` if the session is missing, expired, or
 *   has no matching user.
 */
export async function getCurrentUser(sessionId: string): Promise<MinimalUser | null> {
  // One timestamp drives both the validity check and the cleanup, so they agree on
  // what counts as expired. Previously the cleanup used the database clock (NOW()).
  const now = new Date();
  const [row] = await dbGlobal
    .select({
      userId: users.id,
      username: users.username,
    })
    .from(sessions)
    .innerJoin(users, eq(sessions.userId, users.id))
    .where(and(eq(sessions.id, sessionId), gt(sessions.expiresAt, now)));
  if (!row) {
    await dbGlobal
      .delete(sessions)
      .where(and(eq(sessions.id, sessionId), lte(sessions.expiresAt, now)));
    return null;
  }
  return {
    id: row.userId,
    username: row.username,
  };
}