You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
5.3 KiB
193 lines
5.3 KiB
import { dbGlobal } from "drizzle-pkg/lib/db";
|
|
import { sessions, users } from "drizzle-pkg/lib/schema/auth";
|
|
import { and, eq, gt, sql } from "drizzle-orm";
|
|
import log4js from "logger";
|
|
import { randomUUID } from "crypto";
|
|
import { compare, hash } from "bcryptjs";
|
|
|
|
/** Module logger, obtained under the "AUTH" category name. */
const logger = log4js.getLogger("AUTH");
|
|
|
|
/** Allowed usernames: 3 to 20 characters drawn from a-z, A-Z, 0-9 and underscore. */
const USERNAME_REGEX = /^[a-zA-Z0-9_]{3,20}$/;

/** Minimum password length, compared against `password.length`. */
const MIN_PASSWORD_LENGTH = 6;

/** Session lifetime in milliseconds (7 days). */
const SESSION_EXPIRE_MS = 1000 * 60 * 60 * 24 * 7;
|
|
|
|
/** Credentials accepted by registerUser and loginUser. */
type AuthPayload = { username: string; password: string };

/** User fields handed back to callers: only id and username, no password field. */
type MinimalUser = { id: number; username: string };
|
|
|
|
/** Thrown when submitted credentials are malformed or fail the format rules. */
export class AuthValidationError extends Error {
  override name = "AuthValidationError";

  constructor(message: string) {
    super(message);
  }
}
|
|
|
|
/** Thrown when a registration collides with an existing username. */
export class AuthConflictError extends Error {
  override name = "AuthConflictError";

  constructor(message: string) {
    super(message);
  }
}
|
|
|
|
/** Thrown when a login attempt is rejected (unknown user or wrong password). */
export class AuthFailedError extends Error {
  override name = "AuthFailedError";

  constructor(message: string) {
    super(message);
  }
}
|
|
|
|
/**
 * Asserts that `payload` is an object with string `username` and `password`
 * fields satisfying the format rules, narrowing it to AuthPayload on success.
 *
 * @param payload - Untrusted input, e.g. a parsed request body.
 * @throws AuthValidationError if payload is not an object, either field is not
 *   a string, the username does not match USERNAME_REGEX, or the password is
 *   shorter than MIN_PASSWORD_LENGTH.
 */
function validateCredentials(payload: unknown): asserts payload is AuthPayload {
  if (typeof payload !== "object" || payload === null) {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }

  // Fields are `unknown` until checked below; no pretence that they are strings.
  const { username, password } = payload as { username?: unknown; password?: unknown };

  if (typeof username !== "string" || typeof password !== "string") {
    throw new AuthValidationError("用户名和密码必须是字符串");
  }

  if (!USERNAME_REGEX.test(username)) {
    throw new AuthValidationError("用户名格式不正确");
  }

  if (password.length < MIN_PASSWORD_LENGTH) {
    // Derive the number from the constant so the message cannot drift from the rule.
    throw new AuthValidationError(`密码长度至少 ${MIN_PASSWORD_LENGTH} 位`);
  }
}
|
|
|
|
/**
 * Builds the single generic login-failure error. The same message is used for an
 * unknown username and a wrong password, so callers cannot tell the two apart.
 */
function authFailedError(): AuthFailedError {
  const message = "用户名或密码错误";
  return new AuthFailedError(message);
}
|
|
|
|
/**
 * Inserts a new session row for `userId` and returns its id and expiry.
 * The id is a random UUID; the expiry is SESSION_EXPIRE_MS after the current time.
 */
async function createSession(userId: number) {
  const session = {
    id: randomUUID(),
    userId,
    expiresAt: new Date(Date.now() + SESSION_EXPIRE_MS),
  };

  await dbGlobal.insert(sessions).values(session);

  return { sessionId: session.id, expiresAt: session.expiresAt };
}
|
|
|
|
/**
 * Computes the next user id as MAX(users.id) + 1, or 1 for an empty table.
 *
 * Two concurrent registrations can read the same MAX, so the returned id may
 * already be taken by the time it is inserted; the caller must retry on a
 * primary-key collision.
 *
 * @throws Error if the database returns a value that is not a safe integer.
 */
async function getNextUserId(): Promise<number> {
  const [row] = await dbGlobal
    .select({ maxId: sql<number>`COALESCE(MAX(${users.id}), 0)` })
    .from(users);

  // `sql<number>` only types the result; it performs no runtime conversion.
  // Depending on the column type and driver, aggregates can arrive as strings
  // (node-postgres does this for int8/bigint), which would turn `+ 1` into
  // string concatenation. Coerce explicitly.
  const rawMax: unknown = row?.maxId ?? 0;
  const maxId = Number(rawMax);
  if (!Number.isSafeInteger(maxId)) {
    throw new Error(`unexpected MAX(users.id) value: ${String(rawMax)}`);
  }
  return maxId + 1;
}
|
|
|
|
/** Postgres SQLSTATE code for `unique_violation`. */
const PG_UNIQUE_VIOLATION = "23505";

/** The parts of a unique-violation error used to tell which column collided. */
type UniqueViolation = { constraint: string; detail: string };

/**
 * Returns the unique-constraint violation carried by `err`, or null if `err` is
 * any other failure. Walks a few levels of `cause`, because the ORM may wrap the
 * driver's error. Assumes a Postgres driver exposing `code` plus `constraint`
 * (node-postgres) or `constraint_name` (postgres.js) and `detail` — TODO confirm
 * for the driver behind dbGlobal.
 */
function findUniqueViolation(err: unknown): UniqueViolation | null {
  let current: unknown = err;
  for (let depth = 0; depth < 4; depth++) {
    if (typeof current !== "object" || current === null) {
      return null;
    }
    const candidate = current as {
      code?: unknown;
      constraint?: unknown;
      constraint_name?: unknown;
      detail?: unknown;
      cause?: unknown;
    };
    if (candidate.code === PG_UNIQUE_VIOLATION) {
      const constraint = candidate.constraint ?? candidate.constraint_name;
      return {
        constraint: typeof constraint === "string" ? constraint : "",
        detail: typeof candidate.detail === "string" ? candidate.detail : "",
      };
    }
    current = candidate.cause;
  }
  return null;
}

/**
 * True when the violation's constraint name contains `field`, or its detail text
 * contains `(field)` (Postgres details read like `Key (username)=(bob) already exists.`).
 */
function violationInvolvesField(violation: UniqueViolation, field: string): boolean {
  return violation.constraint.includes(field) || violation.detail.includes(`(${field})`);
}

/**
 * Inserts a user with an application-assigned id (MAX(id) + 1 from
 * getNextUserId), retrying when another registration took the same id first.
 *
 * @param username - Already-validated username.
 * @param passwordHash - Hash to store in the password column.
 * @returns The new user's id and username.
 * @throws AuthConflictError if the username is already taken.
 * @throws Error("创建用户失败,请稍后重试") if every attempt collided on a
 *   non-username unique key, or the insert returned no row.
 * @throws the original database error for any non-unique-violation failure.
 */
async function insertUserWithRetry(username: string, passwordHash: string): Promise<MinimalUser> {
  const maxRetry = 5;

  for (let attempt = 0; attempt < maxRetry; attempt++) {
    const userId = await getNextUserId();

    try {
      const [newUser] = await dbGlobal
        .insert(users)
        .values({ id: userId, username, password: passwordHash })
        .returning({ id: users.id, username: users.username });

      if (!newUser) {
        // Not a unique violation, so the catch below rethrows it unchanged.
        throw new Error("创建用户失败,请稍后重试");
      }
      return newUser;
    } catch (err) {
      const violation = findUniqueViolation(err);
      if (violation === null) {
        throw err;
      }
      if (violationInvolvesField(violation, "username")) {
        throw new AuthConflictError("用户名已存在");
      }
      // Another unique key collided — presumably the primary key: a concurrent
      // registration inserted this id after we read MAX(id). Recompute and retry.
    }
  }

  // Every attempt lost the id race.
  throw new Error("创建用户失败,请稍后重试");
}
|
|
|
|
/**
 * Creates a new account and returns its public fields.
 *
 * The password is hashed with bcrypt (cost factor 10) before it is stored.
 *
 * @param payload - Desired username and plaintext password.
 * @returns The created user's id and username.
 * @throws AuthValidationError when the credentials fail validation.
 * @throws AuthConflictError when the username is already registered.
 */
export async function registerUser(payload: AuthPayload): Promise<MinimalUser> {
  validateCredentials(payload);
  const { username: requestedName, password: plainPassword } = payload;

  const storedHash = await hash(plainPassword, 10);
  const created = await insertUserWithRetry(requestedName, storedHash);

  logger.info("user registered: %s", requestedName);
  return created;
}
|
|
|
|
/**
 * A syntactically valid bcrypt hash with cost factor 10 (the same cost
 * registerUser uses). It is a fixed value, never read from the database. When
 * the username is unknown, the password is still compared against it, so a miss
 * costs about as much as a wrong password instead of returning early.
 */
const DUMMY_PASSWORD_HASH = "$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy";

/**
 * Authenticates a user and opens a new session.
 *
 * @param payload - Credentials; validated with the same rules as registration.
 * @returns The user's id and username, plus the new session id and its expiry.
 * @throws AuthValidationError if the credentials are malformed.
 * @throws AuthFailedError if the username is unknown or the password does not
 *   match. Both cases use the same message.
 */
export async function loginUser(payload: AuthPayload) {
  validateCredentials(payload);
  const { username, password } = payload;

  const [user] = await dbGlobal
    .select({
      id: users.id,
      username: users.username,
      password: users.password,
    })
    .from(users)
    .where(eq(users.username, username));

  if (!user) {
    // Do the bcrypt work anyway: returning early would let response timing reveal
    // whether the username exists. The comparison result is deliberately ignored.
    await compare(password, DUMMY_PASSWORD_HASH);
    throw authFailedError();
  }

  const isMatch = await compare(password, user.password);
  if (!isMatch) {
    throw authFailedError();
  }

  const { sessionId, expiresAt } = await createSession(user.id);
  logger.info("user login: %s", username);

  return {
    user: {
      id: user.id,
      username: user.username,
    } satisfies MinimalUser,
    sessionId,
    expiresAt,
  };
}
|
|
|
|
/**
 * Ends a session by deleting the session row with the given id.
 * The number of deleted rows is not checked, so this resolves to true even when
 * no such session existed.
 */
export async function logoutUser(sessionId: string): Promise<boolean> {
  const matchesSession = eq(sessions.id, sessionId);
  await dbGlobal.delete(sessions).where(matchesSession);
  logger.info("session logout");
  return true;
}
|
|
|
|
/**
 * Resolves the user that owns a session, or null when there is no live session.
 *
 * A session is live when its row exists, its expiry is strictly after the
 * current time, and its user row still exists (inner join). On a miss, a row
 * with this id whose expiry is at or before that same time is deleted.
 *
 * @param sessionId - Session id, e.g. from a cookie. Empty values return null
 *   without touching the database.
 * @returns The owner's id and username, or null.
 */
export async function getCurrentUser(sessionId: string): Promise<MinimalUser | null> {
  if (!sessionId) {
    return null;
  }

  // A single timestamp drives both the lookup and the cleanup, so the two
  // statements always agree on what "expired" means. Previously the cleanup used
  // the database's NOW(), which can differ from the app clock or timezone.
  const now = new Date();

  const [row] = await dbGlobal
    .select({
      userId: users.id,
      username: users.username,
    })
    .from(sessions)
    .innerJoin(users, eq(sessions.userId, users.id))
    .where(and(eq(sessions.id, sessionId), gt(sessions.expiresAt, now)));

  if (!row) {
    // NOT (expires_at > now) is expires_at <= now; reusing gt() keeps `now`
    // encoded through the column exactly as in the lookup above.
    await dbGlobal
      .delete(sessions)
      .where(and(eq(sessions.id, sessionId), sql`NOT (${gt(sessions.expiresAt, now)})`));
    return null;
  }

  return {
    id: row.userId,
    username: row.username,
  };
}
|