You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

285 lines
8.4 KiB

import fs from "node:fs/promises";
import path from "node:path";
import { dbGlobal } from "drizzle-pkg/lib/db";
import { userExportTasks } from "drizzle-pkg/lib/schema/export";
import { and, desc, eq, gt, isNotNull, or, sql } from "drizzle-orm";
import { nextIntegerId } from "../../utils/sqlite-id";
import { RELATIVE_TMP_DIR } from "#server/constants/media";
/** How sensitive fields are treated in an export: either masked or emitted raw. */
type ExportMaskPolicy = "raw" | "masked";
/**
 * Read a strictly positive integer from the environment variable `name`.
 *
 * @param name - Environment variable to read.
 * @param fallback - Value returned when the variable is unset, empty, not an
 *   integer, or not greater than zero.
 * @returns The parsed positive integer, or `fallback`.
 */
function positiveIntFromEnv(name: string, fallback: number): number {
  const value = process.env[name];
  if (value === undefined || value === "") {
    return fallback;
  }
  const parsed = Number(value);
  if (!Number.isInteger(parsed) || parsed <= 0) {
    return fallback;
  }
  return parsed;
}
/**
 * Read a strictly positive bigint from the environment variable `name`.
 *
 * @param name - Environment variable to read.
 * @param fallback - Value returned when the variable is unset, empty, not
 *   parseable by `BigInt`, or not greater than zero.
 * @returns The parsed positive bigint, or `fallback`.
 */
function positiveBigIntFromEnv(name: string, fallback: bigint): bigint {
  const value = process.env[name];
  if (!value) {
    return fallback;
  }
  let parsed: bigint;
  try {
    parsed = BigInt(value);
  } catch {
    // Non-integer text (e.g. "1.5", "abc") makes BigInt throw a SyntaxError.
    return fallback;
  }
  return parsed > 0n ? parsed : fallback;
}
/** Upper bound on tasks in "running" state; a new task is rejected at or above it (default 2). */
const EXPORT_MAX_RUNNING_TASKS = positiveIntFromEnv("EXPORT_MAX_RUNNING_TASKS", 2);
/** Upper bound on tasks in "queued" state; a new task is rejected at or above it (default 30). */
const EXPORT_MAX_QUEUED_TASKS = positiveIntFromEnv("EXPORT_MAX_QUEUED_TASKS", 30);
/** Byte budget for unexpired succeeded exports (default 2 GiB = 2 * 1024^3 bytes). */
const EXPORT_MAX_RETAINED_BYTES = positiveBigIntFromEnv("EXPORT_MAX_RETAINED_BYTES", 2n * 1024n ** 3n);
/**
 * Absolute path of the directory that holds export outputs:
 * `<cwd>/<RELATIVE_TMP_DIR>/exports`.
 */
function exportRootDir(): string {
  const tmpDir = path.resolve(process.cwd(), RELATIVE_TMP_DIR);
  return path.resolve(tmpDir, "exports");
}
/**
 * Whether `dir`, once resolved to an absolute path, is the export root itself
 * or lies somewhere beneath it.
 *
 * The separator suffix prevents sibling prefixes (e.g. `exports-old`) from
 * matching.
 */
function isPathUnderExportRoot(dir: string): boolean {
  const rootDir = exportRootDir();
  const target = path.resolve(dir);
  if (target === rootDir) {
    return true;
  }
  return target.startsWith(rootDir + path.sep);
}
/**
 * Whether `dir` looks like an export directory from the legacy TMP_DIR layout
 * and is therefore safe to clean up:
 * - its parent directory must be named `exports`;
 * - its own name must look like `export-task-<id>` (digits only).
 *
 * The path is resolved first, so `..` segments cannot disguise the target.
 */
function isSafeLegacyTaskOutputDir(dir: string): boolean {
  const absolute = path.resolve(dir);
  const dirName = path.basename(absolute);
  if (!/^export-task-\d+$/.test(dirName)) {
    return false;
  }
  const parentName = path.basename(path.dirname(absolute));
  return parentName === "exports";
}
/**
 * Fetch a single export task by primary key.
 *
 * @returns The task row, or `null` when no row has this id.
 */
async function getExportTaskById(taskId: number) {
  const rows = await dbGlobal.select().from(userExportTasks).where(eq(userExportTasks.id, taskId)).limit(1);
  return rows[0] ?? null;
}
/**
 * Fetch an export task by id, failing loudly when it does not exist.
 *
 * @throws Error `export task not found: <id>` when the row is missing.
 */
async function getRequiredExportTaskById(taskId: number) {
  const task = await getExportTaskById(taskId);
  if (task) {
    return task;
  }
  throw new Error(`export task not found: ${taskId}`);
}
/**
 * Count export tasks across all users that are currently in `status`.
 */
async function countExportTasksWithStatus(status: "queued" | "running"): Promise<number> {
  const [row] = await dbGlobal
    .select({ taskCount: sql<number>`count(*)` })
    .from(userExportTasks)
    .where(eq(userExportTasks.status, status));
  return Number(row?.taskCount ?? 0);
}

/**
 * Sum `totalBytes` over succeeded exports whose `expiresAt` is still after `now`.
 */
async function sumRetainedExportBytes(now: Date): Promise<bigint> {
  const [row] = await dbGlobal
    .select({
      // SQLite's sum() yields an integer (or 0 via coalesce); other drivers may
      // hand back a string or bigint, all of which BigInt() accepts.
      retainedBytes: sql<number | string | bigint>`coalesce(sum(${userExportTasks.totalBytes}), 0)`,
    })
    .from(userExportTasks)
    .where(
      and(
        eq(userExportTasks.status, "succeeded"),
        isNotNull(userExportTasks.expiresAt),
        // `gt` binds `now` through the column's own encoder, so the comparison
        // uses the same representation as the stored timestamp. Interpolating
        // the Date into a raw `sql` template would hand the driver an
        // unencoded Date instead.
        gt(userExportTasks.expiresAt, now),
      ),
    );
  return BigInt(row?.retainedBytes ?? 0);
}

/**
 * Create a new queued export task for a user after enforcing global limits.
 *
 * Checks run in this order, and each failure is thrown as an h3-style
 * `{ statusCode, statusMessage }` object:
 * - 409 when the user already has a queued or running task;
 * - 429 when running tasks reach EXPORT_MAX_RUNNING_TASKS;
 * - 429 when queued tasks reach EXPORT_MAX_QUEUED_TASKS;
 * - 507 when unexpired succeeded exports reach EXPORT_MAX_RETAINED_BYTES.
 *
 * NOTE(review): the checks and the insert are separate statements with no
 * transaction, so concurrent requests may briefly exceed the limits.
 *
 * @returns The freshly inserted task row (status "queued").
 */
export async function createExportTask(params: { userId: number; maskPolicy: ExportMaskPolicy }) {
  const [activeTask] = await dbGlobal
    .select({ id: userExportTasks.id })
    .from(userExportTasks)
    .where(
      and(
        eq(userExportTasks.userId, params.userId),
        or(eq(userExportTasks.status, "queued"), eq(userExportTasks.status, "running")),
      ),
    )
    .limit(1);
  if (activeTask) {
    throw { statusCode: 409, statusMessage: "已有导出任务在处理中,请稍后再试" };
  }

  const runningCount = await countExportTasksWithStatus("running");
  if (runningCount >= EXPORT_MAX_RUNNING_TASKS) {
    throw {
      statusCode: 429,
      statusMessage: "导出任务繁忙,请稍后重试",
    };
  }

  const queuedCount = await countExportTasksWithStatus("queued");
  if (queuedCount >= EXPORT_MAX_QUEUED_TASKS) {
    throw {
      statusCode: 429,
      statusMessage: "导出排队过多,请稍后重试",
    };
  }

  const retainedBytes = await sumRetainedExportBytes(new Date());
  if (retainedBytes >= EXPORT_MAX_RETAINED_BYTES) {
    throw {
      statusCode: 507,
      statusMessage: "导出空间已达上限,请先删除旧导出后重试",
    };
  }

  const id = await nextIntegerId(userExportTasks, userExportTasks.id);
  await dbGlobal.insert(userExportTasks).values({
    id,
    userId: params.userId,
    maskPolicy: params.maskPolicy,
    status: "queued",
  });
  return getRequiredExportTaskById(id);
}
/**
 * List every export task owned by `userId`, newest (highest id) first.
 */
export async function listExportTasksByUser(userId: number) {
  const ownedByUser = eq(userExportTasks.userId, userId);
  const tasks = await dbGlobal.select().from(userExportTasks).where(ownedByUser).orderBy(desc(userExportTasks.id));
  return tasks;
}
/**
 * Transition a task from "queued" to "running", stamping `exportCutoffAt`.
 *
 * The UPDATE only matches a queued row. Afterwards the row is re-read, and the
 * claim counts as ours only if it is running with exactly the cutoff we wrote.
 * That detects another worker having claimed it first.
 *
 * NOTE(review): the cutoff equality assumes `exportCutoffAt` round-trips
 * millisecond precision; if the column stores whole seconds the check would
 * fail even for a successful claim — confirm against the schema.
 *
 * @throws Error `invalid export task transition ...` when this call did not
 *   perform the queued -> running transition.
 */
export async function markExportTaskRunning(taskId: number) {
  const cutoffAt = new Date();
  await dbGlobal
    .update(userExportTasks)
    .set({ exportCutoffAt: cutoffAt, status: "running" })
    .where(and(eq(userExportTasks.status, "queued"), eq(userExportTasks.id, taskId)));

  const task = await getRequiredExportTaskById(taskId);
  const claimedByUs = task.status === "running" && task.exportCutoffAt?.getTime() === cutoffAt.getTime();
  if (!claimedByUs) {
    throw new Error(`invalid export task transition for ${taskId}: expected queued -> running`);
  }
  return task;
}
/**
 * Claim the oldest (lowest id) queued task and mark it running.
 *
 * Losing a race to another claimer surfaces as an "invalid export task
 * transition" error; in that case the next queued task is tried, up to five
 * attempts in total. Any other error is rethrown.
 *
 * @returns The claimed task row, or `null` when nothing is queued or every
 *   attempt lost its race.
 */
export async function claimNextQueuedTask() {
  const maxAttempts = 5;
  for (let attempt = 0; attempt < maxAttempts; attempt++) {
    const candidates = await dbGlobal
      .select({ id: userExportTasks.id })
      .from(userExportTasks)
      .where(eq(userExportTasks.status, "queued"))
      .orderBy(userExportTasks.id)
      .limit(1);
    const next = candidates[0];
    if (next === undefined) {
      return null;
    }
    try {
      return await markExportTaskRunning(next.id);
    } catch (error) {
      const lostRace = error instanceof Error && error.message.includes("invalid export task transition");
      if (!lostRace) {
        throw error;
      }
    }
  }
  return null;
}
/**
 * Transition a task from "running" to "succeeded", recording its output
 * location, size and expiry, and clearing any previous error fields.
 *
 * The UPDATE only matches a running row. The row is then re-read to confirm
 * the resulting status.
 *
 * @throws Error `invalid export task transition ...` when the task does not
 *   end up "succeeded".
 */
export async function markExportTaskSucceeded(
  taskId: number,
  payload: {
    outputDir: string;
    outputName: string;
    totalBytes: number;
    expiresAt: Date;
  },
) {
  const { outputDir, outputName, totalBytes, expiresAt } = payload;
  await dbGlobal
    .update(userExportTasks)
    .set({
      status: "succeeded",
      outputDir,
      outputName,
      totalBytes,
      expiresAt,
      errorCode: null,
      errorMessage: null,
    })
    .where(and(eq(userExportTasks.status, "running"), eq(userExportTasks.id, taskId)));

  const task = await getRequiredExportTaskById(taskId);
  if (task.status === "succeeded") {
    return task;
  }
  throw new Error(`invalid export task transition for ${taskId}: expected running -> succeeded`);
}
/**
 * Transition a task from "running" to "failed", recording the error code and
 * message.
 *
 * The UPDATE only matches a running row. The row is then re-read to confirm
 * the resulting status.
 *
 * @throws Error `invalid export task transition ...` when the task does not
 *   end up "failed".
 */
export async function markExportTaskFailed(
  taskId: number,
  payload: {
    errorCode: string;
    errorMessage: string;
  },
) {
  const { errorCode, errorMessage } = payload;
  await dbGlobal
    .update(userExportTasks)
    .set({ status: "failed", errorCode, errorMessage })
    .where(and(eq(userExportTasks.status, "running"), eq(userExportTasks.id, taskId)));

  const task = await getRequiredExportTaskById(taskId);
  if (task.status === "failed") {
    return task;
  }
  throw new Error(`invalid export task transition for ${taskId}: expected running -> failed`);
}
/**
 * Mark a task as "expired" with error code `EXPORT_EXPIRED` and the given
 * message.
 *
 * Unlike the other transitions, this update does not filter on the current
 * status.
 *
 * @returns The updated task row.
 * @throws Error `export task not found: <id>` when the row does not exist.
 */
export async function markExportTaskExpired(taskId: number, message: string) {
  const byId = eq(userExportTasks.id, taskId);
  await dbGlobal
    .update(userExportTasks)
    .set({ errorMessage: message, errorCode: "EXPORT_EXPIRED", status: "expired" })
    .where(byId);
  const task = await getRequiredExportTaskById(taskId);
  return task;
}
/**
 * Fetch a task only if it belongs to `userId`.
 *
 * @returns The task row, or `null` when it does not exist or is owned by
 *   another user.
 */
export async function getExportTaskForUser(taskId: number, userId: number) {
  const ownedTask = and(eq(userExportTasks.id, taskId), eq(userExportTasks.userId, userId));
  const rows = await dbGlobal.select().from(userExportTasks).where(ownedTask).limit(1);
  return rows[0] ?? null;
}
/**
 * Delete one of a user's export tasks together with its output directory.
 *
 * The output directory is removed only when it is one of the following:
 * - a path strictly inside the current export root;
 * - a path matching the legacy `.../exports/export-task-<id>` layout.
 *
 * The export root itself is never removed, because it holds every task's
 * output. Filesystem cleanup is best-effort: a failure is logged and the
 * database row is still deleted.
 *
 * @throws `{ statusCode: 404 }` when the task does not exist for this user.
 * @throws `{ statusCode: 409 }` when the task is currently running.
 */
export async function deleteExportTaskForUser(taskId: number, userId: number) {
  const task = await getExportTaskForUser(taskId, userId);
  if (!task) {
    throw { statusCode: 404, statusMessage: "导出任务不存在" };
  }
  if (task.status === "running") {
    throw { statusCode: 409, statusMessage: "任务处理中,暂不可删除" };
  }

  if (task.outputDir) {
    const resolvedDir = path.resolve(task.outputDir);
    // isPathUnderExportRoot() also accepts the root itself. A recursive rm of
    // the root would wipe every user's exports, so refuse that case.
    const isExportRoot = resolvedDir === exportRootDir();
    const isDeletable =
      !isExportRoot && (isPathUnderExportRoot(resolvedDir) || isSafeLegacyTaskOutputDir(resolvedDir));
    if (isDeletable) {
      try {
        await fs.rm(resolvedDir, { recursive: true, force: true });
      } catch (error) {
        // Keep deletion resilient: the DB row is removed even if files linger,
        // but make the leftover visible to operators.
        console.warn(`failed to remove export output for task ${taskId}:`, error);
      }
    }
  }

  await dbGlobal
    .delete(userExportTasks)
    .where(and(eq(userExportTasks.id, taskId), eq(userExportTasks.userId, userId)));
}