import fs from "node:fs/promises";
import path from "node:path";
import { dbGlobal } from "drizzle-pkg/lib/db";
import { userExportTasks } from "drizzle-pkg/lib/schema/export";
import { and, desc, eq, gt, isNotNull, or, sql } from "drizzle-orm";
import { nextIntegerId } from "../../utils/sqlite-id";
import { RELATIVE_TMP_DIR } from "#server/constants/media";

type ExportMaskPolicy = "masked" | "raw";

/** Thrown when a task row that must exist cannot be found by id. */
class ExportTaskNotFoundError extends Error {
  constructor(taskId: number) {
    super(`export task not found: ${taskId}`);
    this.name = "ExportTaskNotFoundError";
  }
}

/** Thrown when a guarded status update did not leave the row in the expected state. */
class ExportTaskTransitionError extends Error {
  constructor(taskId: number, from: string, to: string) {
    super(`invalid export task transition for ${taskId}: expected ${from} -> ${to}`);
    this.name = "ExportTaskTransitionError";
  }
}

/**
 * Reads a positive integer from the environment.
 *
 * @param name - Environment variable name.
 * @param fallback - Value returned when the variable is unset, empty, not an
 *   integer, not positive, or outside the safe-integer range.
 */
function positiveIntFromEnv(name: string, fallback: number): number {
  const raw = process.env[name];
  if (!raw) {
    return fallback;
  }
  const parsed = Number(raw);
  // isSafeInteger (rather than isInteger) rejects values that lost precision.
  return Number.isSafeInteger(parsed) && parsed > 0 ? parsed : fallback;
}

/**
 * Reads a positive bigint from the environment.
 *
 * @param name - Environment variable name.
 * @param fallback - Value returned when the variable is unset, empty,
 *   unparsable by `BigInt`, or not greater than zero.
 */
function positiveBigIntFromEnv(name: string, fallback: bigint): bigint {
  const raw = process.env[name];
  if (!raw) {
    return fallback;
  }
  try {
    const parsed = BigInt(raw);
    return parsed > BigInt(0) ? parsed : fallback;
  } catch {
    // BigInt() throws on non-integer text; treat that as "not configured".
    return fallback;
  }
}

/**
 * Coerces an aggregate value returned by the database driver into a bigint.
 *
 * Numbers are truncated towards zero, strings are parsed as integers first and
 * as numbers second, and anything else (including non-finite numbers) becomes 0.
 * Never throws.
 */
function toBigIntOrZero(value: unknown): bigint {
  if (typeof value === "bigint") {
    return value;
  }
  if (typeof value === "number") {
    return Number.isFinite(value) ? BigInt(Math.trunc(value)) : BigInt(0);
  }
  if (typeof value === "string") {
    try {
      return BigInt(value);
    } catch {
      // e.g. "123.0": not valid BigInt syntax but still a usable number.
      return toBigIntOrZero(Number(value));
    }
  }
  return BigInt(0);
}

/** Upper bound on tasks in status "running" before new tasks are rejected. */
const EXPORT_MAX_RUNNING_TASKS = positiveIntFromEnv("EXPORT_MAX_RUNNING_TASKS", 2);
/** Upper bound on tasks in status "queued" before new tasks are rejected. */
const EXPORT_MAX_QUEUED_TASKS = positiveIntFromEnv("EXPORT_MAX_QUEUED_TASKS", 30);
/** Upper bound on the summed `totalBytes` of unexpired succeeded tasks (default 2 GiB). */
const EXPORT_MAX_RETAINED_BYTES = positiveBigIntFromEnv(
  "EXPORT_MAX_RETAINED_BYTES",
  BigInt(2) * BigInt(1024) * BigInt(1024) * BigInt(1024),
);

/** Maximum number of queued rows `claimNextQueuedTask` tries before giving up. */
const CLAIM_MAX_ATTEMPTS = 5;

/** Absolute path of the `exports` directory under the configured tmp dir. */
function exportRootDir(): string {
  return path.resolve(process.cwd(), RELATIVE_TMP_DIR, "exports");
}

/** Returns true when `dir` resolves to the export root itself or to a path inside it. */
function isPathUnderExportRoot(dir: string): boolean {
  const root = exportRootDir();
  const resolved = path.resolve(dir);
  // Appending the separator prevents "<root>-other" from matching as a child.
  return resolved === root || resolved.startsWith(`${root}${path.sep}`);
}

/**
 * Allows cleanup of export directories created under a legacy TMP_DIR:
 * - the parent directory must be named `exports`
 * - the directory name must look like `export-task-<id>` (digits only)
 */
function isSafeLegacyTaskOutputDir(dir: string): boolean {
  const resolved = path.resolve(dir);
  const parentName = path.basename(path.dirname(resolved));
  const dirName = path.basename(resolved);
  return parentName === "exports" && /^export-task-\d+$/.test(dirName);
}

/** Loads a task row by id, or `null` when no such row exists. */
async function getExportTaskById(taskId: number) {
  const [row] = await dbGlobal
    .select()
    .from(userExportTasks)
    .where(eq(userExportTasks.id, taskId))
    .limit(1);
  return row ?? null;
}

/**
 * Loads a task row by id.
 *
 * @throws Error with message `export task not found: <id>` when the row is missing.
 */
async function getRequiredExportTaskById(taskId: number) {
  const row = await getExportTaskById(taskId);
  if (!row) {
    throw new ExportTaskNotFoundError(taskId);
  }
  return row;
}

/**
 * Creates a new export task in status "queued" after enforcing admission limits.
 *
 * Checks, in order: the user has no queued/running task, the global running
 * count, the global queued count, and the retained-bytes quota across
 * unexpired succeeded tasks.
 *
 * @returns The freshly inserted task row.
 * @throws A plain `{ statusCode, statusMessage }` object with 409, 429 or 507
 *   when one of the limits is hit (shape kept for existing callers).
 */
export async function createExportTask(params: { userId: number; maskPolicy: ExportMaskPolicy }) {
  const [activeTask] = await dbGlobal
    .select({ id: userExportTasks.id })
    .from(userExportTasks)
    .where(
      and(
        eq(userExportTasks.userId, params.userId),
        or(eq(userExportTasks.status, "queued"), eq(userExportTasks.status, "running")),
      ),
    )
    .limit(1);
  if (activeTask) {
    throw { statusCode: 409, statusMessage: "已有导出任务在处理中,请稍后再试" };
  }

  const runningRows = await dbGlobal
    .select({ runningCount: sql`count(*)` })
    .from(userExportTasks)
    .where(eq(userExportTasks.status, "running"));
  const runningCount = Number(runningRows[0]?.runningCount ?? 0);
  if (runningCount >= EXPORT_MAX_RUNNING_TASKS) {
    throw {
      statusCode: 429,
      statusMessage: "导出任务繁忙,请稍后重试",
    };
  }

  const queuedRows = await dbGlobal
    .select({ queuedCount: sql`count(*)` })
    .from(userExportTasks)
    .where(eq(userExportTasks.status, "queued"));
  const queuedCount = Number(queuedRows[0]?.queuedCount ?? 0);
  if (queuedCount >= EXPORT_MAX_QUEUED_TASKS) {
    throw {
      statusCode: 429,
      statusMessage: "导出排队过多,请稍后重试",
    };
  }

  const now = new Date();
  const retainedRows = await dbGlobal
    .select({
      retainedBytes: sql`coalesce(sum(${userExportTasks.totalBytes}), 0)`,
    })
    .from(userExportTasks)
    .where(
      and(
        eq(userExportTasks.status, "succeeded"),
        isNotNull(userExportTasks.expiresAt),
        gt(userExportTasks.expiresAt, now),
      ),
    );
  // The aggregate's runtime type depends on the driver; coerce defensively so
  // an unexpected representation cannot crash task creation.
  const retainedBytes = toBigIntOrZero(retainedRows[0]?.retainedBytes);
  if (retainedBytes >= EXPORT_MAX_RETAINED_BYTES) {
    throw {
      statusCode: 507,
      statusMessage: "导出空间已达上限,请先删除旧导出后重试",
    };
  }

  const id = await nextIntegerId(userExportTasks, userExportTasks.id);
  await dbGlobal.insert(userExportTasks).values({
    id,
    userId: params.userId,
    maskPolicy: params.maskPolicy,
    status: "queued",
  });
  return getRequiredExportTaskById(id);
}

/** Lists every export task owned by `userId`, highest id first. */
export async function listExportTasksByUser(userId: number) {
  return dbGlobal
    .select()
    .from(userExportTasks)
    .where(eq(userExportTasks.userId, userId))
    .orderBy(desc(userExportTasks.id));
}

/**
 * Moves a task from "queued" to "running" and stamps `exportCutoffAt`.
 *
 * @returns The updated row.
 * @throws Error `export task not found: <id>` when the row no longer exists.
 * @throws Error `invalid export task transition ...` when the row is not
 *   running with the cutoff written by this call.
 */
export async function markExportTaskRunning(taskId: number) {
  const cutoffAt = new Date();
  await dbGlobal
    .update(userExportTasks)
    .set({
      status: "running",
      exportCutoffAt: cutoffAt,
    })
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.status, "queued")));

  const row = await getRequiredExportTaskById(taskId);
  // The update is conditional, so re-read and verify that the stored cutoff is
  // the one written here; otherwise another caller performed the transition.
  // NOTE(review): relies on the column round-tripping millisecond precision —
  // confirm against the schema.
  const claimedHere =
    row.status === "running" && row.exportCutoffAt?.getTime() === cutoffAt.getTime();
  if (!claimedHere) {
    throw new ExportTaskTransitionError(taskId, "queued", "running");
  }
  return row;
}

/**
 * Claims the queued task with the lowest id by marking it running.
 *
 * Retries with the next queued row when the selected one was claimed or
 * deleted in the meantime, up to `CLAIM_MAX_ATTEMPTS` times.
 *
 * @returns The claimed row, or `null` when nothing is queued or every attempt
 *   lost the race.
 */
export async function claimNextQueuedTask() {
  for (let attempt = 0; attempt < CLAIM_MAX_ATTEMPTS; attempt += 1) {
    const [queued] = await dbGlobal
      .select({ id: userExportTasks.id })
      .from(userExportTasks)
      .where(eq(userExportTasks.status, "queued"))
      .orderBy(userExportTasks.id)
      .limit(1);
    if (!queued) {
      return null;
    }
    try {
      return await markExportTaskRunning(queued.id);
    } catch (error) {
      // Both outcomes mean the candidate vanished between SELECT and UPDATE
      // (claimed by someone else, or deleted by its owner); try the next one.
      const lostRace =
        error instanceof ExportTaskTransitionError || error instanceof ExportTaskNotFoundError;
      if (!lostRace) {
        throw error;
      }
    }
  }
  return null;
}

/**
 * Moves a task from "running" to "succeeded", recording its output location,
 * size and expiry, and clearing any previous error fields.
 *
 * @returns The updated row.
 * @throws Error when the row is missing or did not end up "succeeded".
 */
export async function markExportTaskSucceeded(
  taskId: number,
  payload: {
    outputDir: string;
    outputName: string;
    totalBytes: number;
    expiresAt: Date;
  },
) {
  await dbGlobal
    .update(userExportTasks)
    .set({
      status: "succeeded",
      outputDir: payload.outputDir,
      outputName: payload.outputName,
      totalBytes: payload.totalBytes,
      expiresAt: payload.expiresAt,
      errorCode: null,
      errorMessage: null,
    })
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.status, "running")));

  const row = await getRequiredExportTaskById(taskId);
  if (row.status !== "succeeded") {
    throw new ExportTaskTransitionError(taskId, "running", "succeeded");
  }
  return row;
}

/**
 * Moves a task from "running" to "failed" and stores the error details.
 *
 * @returns The updated row.
 * @throws Error when the row is missing or did not end up "failed".
 */
export async function markExportTaskFailed(
  taskId: number,
  payload: {
    errorCode: string;
    errorMessage: string;
  },
) {
  await dbGlobal
    .update(userExportTasks)
    .set({
      status: "failed",
      errorCode: payload.errorCode,
      errorMessage: payload.errorMessage,
    })
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.status, "running")));

  const row = await getRequiredExportTaskById(taskId);
  if (row.status !== "failed") {
    throw new ExportTaskTransitionError(taskId, "running", "failed");
  }
  return row;
}

/**
 * Marks a task "expired" with error code `EXPORT_EXPIRED`, regardless of its
 * current status.
 *
 * @returns The updated row.
 * @throws Error `export task not found: <id>` when the row is missing.
 */
export async function markExportTaskExpired(taskId: number, message: string) {
  await dbGlobal
    .update(userExportTasks)
    .set({
      status: "expired",
      errorCode: "EXPORT_EXPIRED",
      errorMessage: message,
    })
    .where(eq(userExportTasks.id, taskId));
  return getRequiredExportTaskById(taskId);
}

/** Loads a task by id only if it belongs to `userId`; otherwise `null`. */
export async function getExportTaskForUser(taskId: number, userId: number) {
  const [row] = await dbGlobal
    .select()
    .from(userExportTasks)
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.userId, userId)))
    .limit(1);
  return row ?? null;
}

/**
 * Deletes a user's export task and, when the recorded output directory passes
 * the path safety checks, its files on disk.
 *
 * File removal is best-effort: a failure is logged and the row is still deleted.
 *
 * @throws A plain `{ statusCode, statusMessage }` object with 404 when the task
 *   does not exist for this user, or 409 while it is running.
 */
export async function deleteExportTaskForUser(taskId: number, userId: number) {
  const task = await getExportTaskForUser(taskId, userId);
  if (!task) {
    throw { statusCode: 404, statusMessage: "导出任务不存在" };
  }
  if (task.status === "running") {
    throw { statusCode: 409, statusMessage: "任务处理中,暂不可删除" };
  }

  const outputDir = task.outputDir;
  // Only ever remove directories we can prove are export outputs.
  const removable =
    Boolean(outputDir) &&
    (isPathUnderExportRoot(outputDir as string) || isSafeLegacyTaskOutputDir(outputDir as string));
  if (outputDir && removable) {
    try {
      await fs.rm(outputDir, { recursive: true, force: true });
    } catch (error) {
      // Keep deletion resilient, but leave a trace so orphaned files can be found.
      console.warn(`failed to remove export output dir for task ${taskId}: ${outputDir}`, error);
    }
  }

  await dbGlobal
    .delete(userExportTasks)
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.userId, userId)));
}