@paroicms/site-generator-plugin
Version:
ParoiCMS Site Generator Plugin
157 lines (156 loc) • 5.69 kB
JavaScript
import { isDef, messageOf } from "@paroicms/public-anywhere-lib";
// Module-level counter, incremented once per batch upload to number the
// input file name (`batch-input-<seq>.jsonl`).
let seq = 0;
/**
 * Sends a list of prompts through the batch pipeline and summarizes the outcome.
 *
 * @param {object} ctx - Context forwarded to `execBatchInvokeMinistral`; its
 *   `mistralModelName` is copied into the report.
 * @param {string[]} prompts - Prompts forwarded to `execBatchInvokeMinistral`.
 * @param {object} options - Forwarded as-is; `llmTaskName` is copied into the report.
 * @returns {Promise<{ llmMessages: Array, llmReport: object }>} The first-choice
 *   message content of every batch record (entries rejected by `isDef` are
 *   dropped) plus a report holding the task name, model name, summed prompt
 *   and completion token counts, and the elapsed time in milliseconds.
 */
export async function batchInvokeMinistral(ctx, prompts, options) {
  const startedAt = Date.now();
  const batchRecords = await execBatchInvokeMinistral(ctx, prompts, options);

  // Pass 1: first choice of every record; a record without choices yields undefined.
  const contents = [];
  for (const record of batchRecords) {
    contents.push(record.response.body.choices[0]?.message.content);
  }
  const llmMessages = contents.filter(isDef);

  // Pass 2: token accounting, prompt tokens first, then completion tokens.
  let promptTokenTotal = 0;
  for (const record of batchRecords) {
    promptTokenTotal += record.response.body.usage.prompt_tokens;
  }
  const elapsedMs = Date.now() - startedAt;
  let completionTokenTotal = 0;
  for (const record of batchRecords) {
    completionTokenTotal += record.response.body.usage.completion_tokens;
  }

  const llmReport = {
    llmTaskName: options.llmTaskName,
    modelName: ctx.mistralModelName,
    inputTokenCount: promptTokenTotal,
    durationMs: elapsedMs,
    outputTokenCount: completionTokenTotal,
  };
  return { llmMessages, llmReport };
}
/**
 * Runs the prompts as one Mistral batch job and returns the parsed output records.
 *
 * Steps: serialize one chat-completion request per prompt as JSONL (the
 * `custom_id` is the prompt index), upload it with purpose "batch", create a
 * batch job on "/v1/chat/completions", wait for it via `waitJobCompletion`,
 * download and parse the output file. Every file id collected along the way
 * (input file, then output file) is deleted in the `finally` block; deletion
 * failures are only logged.
 *
 * @param {object} ctx - Provides `mistral` (client), `mistralModelName` and `logger`.
 * @param {string[]} prompts - One user message per batch request.
 * @param {object} options - Reads `maxTokens`, `temperature` and `timeoutMs`.
 * @returns {Promise<object[]>} Parsed output records, ordered by prompt index
 *   when every record carries a numeric `custom_id`. Empty when `prompts` is empty.
 * @throws {Error} When the job does not succeed or the output cannot be parsed.
 */
async function execBatchInvokeMinistral(ctx, prompts, options) {
  const { mistral, mistralModelName, logger } = ctx;
  // Nothing to send: avoid uploading an empty file and creating a pointless job.
  if (prompts.length === 0) return [];
  const uploadedFileIds = [];
  try {
    const messages = prompts
      .map((prompt, index) =>
        JSON.stringify({
          custom_id: `${index}`,
          body: {
            max_tokens: options.maxTokens,
            temperature: options.temperature,
            messages: [
              {
                role: "user",
                content: prompt,
              },
            ],
          },
        }),
      )
      .join("\n");
    const batchData = await mistral.files.upload({
      file: {
        fileName: `batch-input-${seq++}.jsonl`,
        content: Buffer.from(messages),
      },
      purpose: "batch",
    });
    uploadedFileIds.push(batchData.id);
    const createdJob = await mistral.batch.jobs.create({
      inputFiles: [batchData.id],
      model: mistralModelName,
      endpoint: "/v1/chat/completions",
      metadata: { jobType: "batchInvoke" },
      timeoutHours: 1,
    });
    const outputFileId = await waitJobCompletion(ctx, {
      jobId: createdJob.id,
      timeoutMs: options.timeoutMs,
    });
    uploadedFileIds.push(outputFileId);
    const outputFileStream = await mistral.files.download({ fileId: outputFileId });
    const result = await readAsString(outputFileStream);
    let records;
    try {
      records = result
        .trim()
        .split("\n")
        .map((line) => JSON.parse(line));
    } catch (error) {
      logger.error("[Mistral] Error parsing batch job result:", error, result);
      throw new Error("Failed to parse batch job result", { cause: error });
    }
    // Callers consume the records positionally, so restore the prompt order
    // instead of relying on the order of lines in the output file.
    return orderBatchRecordsByCustomId(records);
  } finally {
    for (const fileId of uploadedFileIds) {
      try {
        await mistral.files.delete({ fileId });
      } catch (error) {
        logger.error("[Mistral] Error deleting uploaded file:", error, fileId);
      }
    }
  }
}

/**
 * Returns the records sorted by their numeric `custom_id`; returns the input
 * unchanged when any record lacks a purely numeric string `custom_id`.
 */
function orderBatchRecordsByCustomId(records) {
  const positions = records.map((record) => {
    const customId = record?.custom_id;
    return typeof customId === "string" && /^\d+$/.test(customId) ? Number(customId) : undefined;
  });
  if (positions.includes(undefined)) return records;
  return records
    .map((record, i) => ({ record, position: positions[i] }))
    .sort((a, b) => a.position - b.position)
    .map((entry) => entry.record);
}
/**
 * Polls a batch job every 2 seconds until it leaves the in-progress states
 * ("QUEUED", "RUNNING", "CANCELLATION_REQUESTED"), then returns its output file.
 *
 * If the job is still in progress once `timeoutMs` has elapsed, a cancellation
 * is requested and the function fails. An error raised while polling after at
 * least one successful status fetch is logged, and the last known status is
 * then evaluated.
 *
 * @param {object} ctx - Provides `mistral` (client) and `logger`.
 * @param {{ jobId: string, timeoutMs: number }} options - Job to watch and the
 *   maximum time to wait, in milliseconds.
 * @returns {Promise<string>} The job's `outputFile` value (passed by the caller
 *   to `files.download` as a file id).
 * @throws {Error} When the first status fetch fails, when the cancellation
 *   after a timeout fails, when the final status is not "SUCCESS", or when a
 *   successful job has no output file.
 */
async function waitJobCompletion(ctx, options) {
  const { mistral, logger } = ctx;
  const { jobId, timeoutMs } = options;
  const startTime = Date.now();
  let jobStatus;
  let timeoutOccurred = false;
  try {
    while (true) {
      jobStatus = await mistral.batch.jobs.get({ jobId });
      const { status } = jobStatus;
      if (status === "QUEUED" || status === "RUNNING" || status === "CANCELLATION_REQUESTED") {
        const elapsedTime = Date.now() - startTime;
        if (elapsedTime > timeoutMs) {
          timeoutOccurred = true;
          break;
        }
        await new Promise((resolve) => setTimeout(resolve, 2_000));
        continue;
      }
      if (
        status === "FAILED" ||
        status === "CANCELLED" ||
        status === "TIMEOUT_EXCEEDED" ||
        status === "SUCCESS"
      ) {
        break;
      }
      throw new Error(`Unexpected batch job "${jobStatus.id}" status: "${status}"`);
    }
  } catch (error) {
    if (!jobStatus) {
      throw new Error(
        `[Mistral] Failed to wait for batch job "${jobId}" completion: ${messageOf(error)}`,
        { cause: error },
      );
    }
    // A status was fetched at least once: log and fall through so the last
    // known status decides the outcome below.
    logger.error(`[Mistral] Error while waiting for job "${jobId}" completion:`, error);
  }
  if (!jobStatus) throw new Error("[Mistral] Should have a job status here");
  if (timeoutOccurred) {
    logger.debug(`[Mistral] Batch job "${jobId}" timed out after ${timeoutMs}ms. Attempting to cancel…`);
    try {
      jobStatus = await mistral.batch.jobs.cancel({ jobId });
    } catch (error) {
      throw new Error(
        `[Mistral] Failed to cancel batch job "${jobId}" after timeout: ${messageOf(error)}`,
        { cause: error },
      );
    }
  }
  const { status, errors } = jobStatus;
  if (status !== "SUCCESS") {
    // `errors` may be absent; never let that mask the real failure with a TypeError.
    const details = Array.isArray(errors) ? errors.map((e) => e.message) : [];
    // Without this, a timeout would surface as an unexplained cancellation status.
    if (timeoutOccurred) details.push(`timed out after ${timeoutMs}ms`);
    const errMessages = details.join(", ");
    throw new Error(`[Mistral] Batch job ${jobStatus.id} failed with status "${status}": ${errMessages}`);
  }
  if (!jobStatus.outputFile) throw new Error("[Mistral] Missing output file");
  return jobStatus.outputFile;
}
/**
 * Drains a web `ReadableStream` of bytes and returns its content decoded as UTF-8.
 *
 * A single streaming decoder is shared by all chunks so that a multi-byte
 * character split across two chunks is decoded correctly.
 *
 * @param {ReadableStream} stream - Stream of byte chunks accepted by `TextDecoder.decode`.
 * @returns {Promise<string>} The full decoded text; rejects if the stream errors.
 */
async function readAsString(stream) {
  const decoder = new TextDecoder("utf-8");
  const output = [];
  // Awaiting pipeTo() propagates source errors without leaving a floating promise.
  await stream.pipeTo(
    new WritableStream({
      write(chunk) {
        // `stream: true` keeps an incomplete trailing byte sequence for the next chunk.
        output.push(decoder.decode(chunk, { stream: true }));
      },
    }),
  );
  // Flush whatever the decoder is still holding.
  output.push(decoder.decode());
  return output.join("");
}