You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

245 lines
7.1 KiB

import { dbGlobal } from "drizzle-pkg/lib/db";
import { sessions, users } from "drizzle-pkg/lib/schema/auth";
import { and, eq, gt, lte, sql } from "drizzle-orm";
import { isUniqueConflictExceptField, isUniqueConflictOnField } from "../../utils/db-unique-constraint";
import log4js from "logger";
import { randomUUID } from "crypto";
import { compare, hash } from "bcryptjs";
/** Logger for the authentication module. */
const logger = log4js.getLogger("AUTH");

/** Usernames: 3 to 20 characters drawn from ASCII letters, digits and underscore. */
const USERNAME_REGEX = /^[a-zA-Z0-9_]{3,20}$/;

/** Minimum accepted password length, measured with String#length. */
const MIN_PASSWORD_LENGTH = 6;

/** Session lifetime in milliseconds (7 days). */
const SESSION_EXPIRE_MS = 7 * 24 * 3600 * 1000;
/** Raw credentials as submitted by a client; shape is checked by `validateCredentials`. */
type AuthPayload = {
  username: string;
  password: string;
};

/** Public-facing subset of a user row that is safe to return to callers (no password hash). */
export type MinimalUser = {
  id: number;
  username: string;
  role: string;
  publicSlug: string | null;
  nickname: string | null;
  avatar: string | null;
};
/** Thrown when a credentials payload has the wrong shape or fails a format rule. */
export class AuthValidationError extends Error {
  override name = "AuthValidationError";

  constructor(message: string) {
    super(message);
  }
}

/** Thrown when a user cannot be created because the username is already taken. */
export class AuthConflictError extends Error {
  override name = "AuthConflictError";

  constructor(message: string) {
    super(message);
  }
}

/** Thrown for any failed login; the message is intentionally the same for every cause. */
export class AuthFailedError extends Error {
  override name = "AuthFailedError";

  constructor(message: string) {
    super(message);
  }
}
/**
 * Asserts that `payload` is a well-formed `AuthPayload`.
 *
 * @param payload - Untrusted input, typically a parsed request body.
 * @throws AuthValidationError if `payload` is not an object, if `username` or
 *   `password` is not a string, if `username` does not match `USERNAME_REGEX`,
 *   or if `password` is shorter than `MIN_PASSWORD_LENGTH`.
 */
function validateCredentials(payload: unknown): asserts payload is AuthPayload {
  if (typeof payload !== "object" || payload === null) {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  const { username, password } = payload as Partial<AuthPayload>;
  if (typeof username !== "string" || typeof password !== "string") {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }
  if (!USERNAME_REGEX.test(username)) {
    throw new AuthValidationError("用户名格式不正确");
  }
  if (password.length < MIN_PASSWORD_LENGTH) {
    // Build the message from the constant so the reported minimum always matches the enforced one.
    throw new AuthValidationError(`密码长度至少 ${MIN_PASSWORD_LENGTH} 位`);
  }
}
/**
 * Builds the generic login-failure error.
 *
 * Every login failure path throws this same message, so callers cannot tell
 * which check failed.
 */
function authFailedError(): AuthFailedError {
  const message = "用户名或密码错误";
  return new AuthFailedError(message);
}
/**
 * Inserts a new session row for `userId` that expires `SESSION_EXPIRE_MS` from now.
 *
 * @param userId - Id of the user who owns the session.
 * @returns The new session id (a random UUID) and its expiry time.
 */
async function createSession(userId: number) {
  const session = {
    id: randomUUID(),
    userId,
    expiresAt: new Date(Date.now() + SESSION_EXPIRE_MS),
  };
  await dbGlobal.insert(sessions).values(session);
  return { sessionId: session.id, expiresAt: session.expiresAt };
}
/**
 * Computes a candidate id for a new user as `MAX(users.id) + 1` (1 for an empty table).
 *
 * Concurrent callers can receive the same value. `insertUserWithRetry` is
 * expected to handle the resulting unique conflict by calling this again.
 *
 * @returns The next candidate user id.
 */
async function getNextUserId(): Promise<number> {
  const [row] = await dbGlobal
    .select({
      maxId: sql<number>`COALESCE(MAX(${users.id}), 0)`,
    })
    .from(users);
  // `sql<number>` is only a compile-time assertion. Depending on the column type and
  // driver, the aggregate can arrive as a string (e.g. node-postgres returns int8 as
  // text). Coerce it so "+ 1" is arithmetic, not string concatenation.
  return Number(row?.maxId ?? 0) + 1;
}
/**
 * Inserts a user row with an id from `getNextUserId`, retrying on id collisions.
 *
 * @param username - Already-validated username.
 * @param passwordHash - bcrypt hash of the password.
 * @param email - Optional email; it is trimmed, and blank values are stored as absent.
 * @returns The public fields of the created user.
 * @throws AuthConflictError if the username is already taken.
 * @throws Error("创建用户失败,请稍后重试") if every attempt hit a non-username
 *   unique conflict, or if the insert returned no row.
 * @throws Any other database error, rethrown unchanged.
 */
async function insertUserWithRetry(
  username: string,
  passwordHash: string,
  email?: string | null,
): Promise<MinimalUser> {
  const maxRetry = 5;
  // Loop-invariant: normalise the email once instead of on every attempt.
  const normalizedEmail = email?.trim() || undefined;
  for (let attempt = 0; attempt < maxRetry; attempt++) {
    const userId = await getNextUserId();
    let newUser: MinimalUser | undefined = undefined;
    try {
      const rows = await dbGlobal
        .insert(users)
        .values({
          id: userId,
          username,
          password: passwordHash,
          email: normalizedEmail,
        })
        .returning({
          id: users.id,
          username: users.username,
          role: users.role,
          publicSlug: users.publicSlug,
          nickname: users.nickname,
          avatar: users.avatar,
        });
      newUser = rows[0] as MinimalUser | undefined;
    } catch (err) {
      if (isUniqueConflictOnField(err, "username")) {
        throw new AuthConflictError("用户名已存在");
      }
      // Presumably an id collision with a concurrent insert: recompute the id and try again.
      // NOTE(review): a conflict on another unique column (e.g. email, if unique) is also
      // retried here even though a new id cannot fix it. Confirm against the schema.
      if (isUniqueConflictExceptField(err, "username")) {
        continue;
      }
      throw err;
    }
    if (newUser) {
      return newUser;
    }
    // The insert succeeded but returned no row; do not hand an undefined user to callers.
    break;
  }
  // Reached when retries are exhausted (previously unreachable: the last attempt rethrew the raw DB error).
  throw new Error("创建用户失败,请稍后重试");
}
/**
 * Creates an account on behalf of an administrator.
 *
 * Role and status are not set here, so they take the table's column defaults
 * (per the original author: user / active).
 *
 * @param payload - Username, password and optional email for the new account.
 * @returns The public fields of the created user.
 * @throws AuthValidationError if the credentials are malformed.
 * @throws AuthConflictError if the username is already taken.
 */
export async function adminProvisionUser(payload: {
  username: string;
  password: string;
  email?: string | null;
}): Promise<MinimalUser> {
  validateCredentials(payload);
  const { username, password, email } = payload;
  const created = await insertUserWithRetry(username, await hash(password, 10), email);
  logger.info("user provisioned by admin: %s", username);
  return created;
}
/**
 * Self-service registration: validates the credentials, hashes the password
 * (bcrypt cost 10), and inserts a user without an email.
 *
 * @param payload - Username and password submitted by the client.
 * @returns The public fields of the created user.
 * @throws AuthValidationError if the credentials are malformed.
 * @throws AuthConflictError if the username is already taken.
 */
export async function registerUser(payload: AuthPayload): Promise<MinimalUser> {
  validateCredentials(payload);
  const passwordHash = await hash(payload.password, 10);
  const created = await insertUserWithRetry(payload.username, passwordHash, null);
  logger.info("user registered: %s", payload.username);
  return created;
}
/** Cached bcrypt hash of a random secret, compared against when no user row exists. */
let dummyPasswordHashPromise: Promise<string> | undefined;

/**
 * Returns a bcrypt hash with the same cost factor (10) used when storing passwords.
 *
 * Comparing against it keeps the work done for unknown usernames comparable to the
 * work done for real accounts.
 */
function getDummyPasswordHash(): Promise<string> {
  if (!dummyPasswordHashPromise) {
    dummyPasswordHashPromise = hash(randomUUID(), 10);
  }
  return dummyPasswordHashPromise;
}

/**
 * Authenticates a username/password pair and opens a new session.
 *
 * @param payload - Credentials submitted by the client.
 * @returns The user's public fields plus the new session id and expiry.
 * @throws AuthValidationError if the credentials are malformed.
 * @throws AuthFailedError (same message in every case) if the user does not exist,
 *   is not "active", or the password does not match.
 */
export async function loginUser(payload: AuthPayload) {
  validateCredentials(payload);
  const { username, password } = payload;
  const [user] = await dbGlobal
    .select({
      id: users.id,
      username: users.username,
      password: users.password,
      status: users.status,
      role: users.role,
      publicSlug: users.publicSlug,
      nickname: users.nickname,
      avatar: users.avatar,
    })
    .from(users)
    .where(eq(users.username, username));
  // Always pay for one bcrypt comparison, even when the user is missing or inactive.
  // Returning early would make those cases measurably faster and reveal which
  // usernames exist.
  const storedHash = user ? user.password : await getDummyPasswordHash();
  const isMatch = await compare(password, storedHash);
  if (!user || user.status !== "active" || !isMatch) {
    throw authFailedError();
  }
  const { sessionId, expiresAt } = await createSession(user.id);
  logger.info("user login: %s", username);
  return {
    user: {
      id: user.id,
      username: user.username,
      role: user.role,
      publicSlug: user.publicSlug,
      nickname: user.nickname,
      avatar: user.avatar,
    } satisfies MinimalUser,
    sessionId,
    expiresAt,
  };
}
/**
 * Deletes the session with the given id.
 *
 * @param sessionId - Id of the session to remove; an unknown id simply deletes nothing.
 * @returns Always `true` once the delete has completed.
 */
export async function logoutUser(sessionId: string) {
  const matchesSession = eq(sessions.id, sessionId);
  await dbGlobal.delete(sessions).where(matchesSession);
  logger.info("session logout");
  return true;
}
/**
 * Resolves a session id to the user that owns it.
 *
 * Side effects: when the session is missing or expired, any expired row with this
 * id is deleted; when the owning user's status is not "active", the session is
 * deleted.
 *
 * @param sessionId - Session id, as produced by `createSession`.
 * @returns The user's public fields, or `null` if the id is empty, the session is
 *   unknown or expired, or the user is not active.
 */
export async function getCurrentUser(sessionId: string): Promise<MinimalUser | null> {
  // No id means no session. Skip both queries; this also covers callers passing a
  // missing cookie value at runtime.
  if (!sessionId) {
    return null;
  }
  const now = new Date();
  const [row] = await dbGlobal
    .select({
      userId: users.id,
      username: users.username,
      role: users.role,
      publicSlug: users.publicSlug,
      nickname: users.nickname,
      avatar: users.avatar,
      status: users.status,
    })
    .from(sessions)
    .innerJoin(users, eq(sessions.userId, users.id))
    .where(and(eq(sessions.id, sessionId), gt(sessions.expiresAt, now)));
  if (!row) {
    // Unknown, expired, or without a matching user: only the expired case is pruned.
    await dbGlobal.delete(sessions).where(and(eq(sessions.id, sessionId), lte(sessions.expiresAt, now)));
    return null;
  }
  if (row.status !== "active") {
    // A non-active account must not keep a usable session.
    await dbGlobal.delete(sessions).where(eq(sessions.id, sessionId));
    return null;
  }
  return {
    id: row.userId,
    username: row.username,
    role: row.role,
    publicSlug: row.publicSlug,
    nickname: row.nickname,
    avatar: row.avatar,
  };
}