@aj-archipelago/cortex

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.
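Because Cortex fronts these services with GraphQL, a transcription built on the Whisper plugin below is reached like any other GraphQL field. A minimal client sketch follows; the endpoint URL and the `transcribe` pathway name are hypothetical placeholders for illustration, not names defined by this file:

// Hypothetical client call (Node 18+, built-in fetch). The endpoint and
// the `transcribe` pathway name are assumptions, not defined by this file.
const query = `
  query Transcribe($file: String!) {
    transcribe(file: $file) {
      result
    }
  }
`;

const response = await fetch('http://localhost:4000/graphql', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
        query,
        variables: { file: 'https://example.com/interview.mp3' },
    }),
});

const { data } = await response.json();
console.log(data.transcribe.result);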

// openAiWhisperPlugin.js
import ModelPlugin from './modelPlugin.js';
import { config } from '../../config.js';
import FormData from 'form-data';
import fs from 'fs';
import { publishRequestProgress } from '../../lib/redisSubscription.js';
import logger from '../../lib/logger.js';
import CortexRequest from '../../lib/cortexRequest.js';
import { downloadFile, deleteTempPath, convertSrtToText, alignSubtitles, getMediaChunks, markCompletedForCleanUp } from '../../lib/util.js';

const OFFSET_CHUNK = 500; // seconds of offset per chunk, only used if the media helper does not provide offsets

class OpenAIWhisperPlugin extends ModelPlugin {
    constructor(pathway, model) {
        super(pathway, model);
    }

    // Execute the request to the OpenAI Whisper API
    async execute(text, parameters, prompt, cortexRequest) {
        const { pathwayResolver } = cortexRequest;
        const { responseFormat, wordTimestamped, highlightWords, maxLineWidth, maxLineCount, maxWordsPerLine } = parameters;
        const chunks = []; // local temp files to clean up when done

        // Transcribe a single chunk via the OpenAI Whisper API (multipart upload)
        const processChunk = async (uri) => {
            try {
                const cortexRequest = new CortexRequest({ pathwayResolver });
                const chunk = await downloadFile(uri);
                chunks.push(chunk);

                const { language, responseFormat } = parameters;
                const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
                const response_format = responseFormat || 'text';

                const whisperInitCallback = (requestInstance) => {
                    const formData = new FormData();
                    formData.append('file', fs.createReadStream(chunk));
                    formData.append('model', requestInstance.params.model);
                    formData.append('response_format', response_format);
                    language && formData.append('language', language);
                    modelPromptText && formData.append('prompt', modelPromptText);

                    requestInstance.data = formData;
                    requestInstance.addHeaders = { ...formData.getHeaders() };
                };
                cortexRequest.initCallback = whisperInitCallback;

                return this.executeRequest(cortexRequest);
            } catch (err) {
                logger.error(`Error getting word timestamped data from api: ${err}`);
                throw err;
            }
        }

        // Transcribe a single chunk via the timestamped-transcription API
        const processTS = async (uri) => {
            const tsparams = { fileurl: uri };
            const { language } = parameters;

            if (language) tsparams.language = language;
            if (highlightWords) tsparams.highlight_words = "True";
            if (maxLineWidth) tsparams.max_line_width = maxLineWidth;
            if (maxLineCount) tsparams.max_line_count = maxLineCount;
            if (maxWordsPerLine) tsparams.max_words_per_line = maxWordsPerLine;
            tsparams.word_timestamps = !wordTimestamped ? "False" : wordTimestamped;

            const cortexRequest = new CortexRequest({ pathwayResolver });
            const whisperInitCallback = (requestInstance) => {
                requestInstance.data = tsparams;
            };
            cortexRequest.initCallback = whisperInitCallback;

            sendProgress(true, true);
            const res = await this.executeRequest(cortexRequest);
            if (!res) {
                throw new Error('Received null or empty response');
            }
            if (res?.statusCode && res?.statusCode >= 400) {
                throw new Error(res?.message || 'An error occurred.');
            }
            if (!wordTimestamped && !responseFormat) {
                // no response format requested, so convert the SRT/VTT output to plain text
                return convertSrtToText(res);
            }
            return res;
        }

        let result = [];
        let { file } = parameters;
        let totalCount = 0;
        let completedCount = 0;
        let partialCount = 0;
        const { requestId } = pathwayResolver;
        let partialRatio = 0;

        const sendProgress = (partial = false, resetCount = false) => {
            partialCount = resetCount ? 0 : partialCount;
            if (partial) {
                partialCount++;
                const increment = 0.02 / Math.log2(partialCount + 1); // logarithmic diminishing increment
                partialRatio = Math.min(partialRatio + increment, 0.99); // limit to 0.99
            } else {
                partialCount = 0;
                partialRatio = 0;
                completedCount++;
            }
            if (completedCount >= totalCount) return;
            const progress = (completedCount + partialRatio) / totalCount;
            logger.info(`Progress for ${requestId}: ${progress}`);
            publishRequestProgress({
                requestId,
                progress,
                data: null,
            });
        }

        // Process one chunk URI, emitting heartbeat progress while it runs
        const processURI = async (uri) => {
            let result = null;
            let _promise = null;
            let errorOccurred = false;
            const intervalId = setInterval(() => sendProgress(true), 3000);

            // use the timestamped API if the model is oai-whisper-ts
            const useTS = this.modelName === 'oai-whisper-ts';
            if (useTS) {
                _promise = processTS;
            } else {
                _promise = processChunk;
            }

            await _promise(uri).then((ts) => {
                result = ts;
            }).catch((err) => {
                errorOccurred = err;
            }).finally(() => {
                clearInterval(intervalId);
                sendProgress();
            });

            if (errorOccurred) {
                throw errorOccurred;
            }
            return result;
        }

        let offsets = [];
        let uris = [];
        try {
            const mediaChunks = await getMediaChunks(file, requestId);
            if (!mediaChunks || !mediaChunks.length) {
                throw new Error(`Error in getting chunks from media helper for file ${file}`);
            }
            uris = mediaChunks.map((chunk) => chunk?.uri || chunk);
            offsets = mediaChunks.map((chunk, index) => chunk?.offset || index * OFFSET_CHUNK);
            totalCount = mediaChunks.length + 1; // one more than the chunk count so reported progress stays below 1.0 until the job finishes

            const batchSize = 4;
            sendProgress();

            // transcribe chunks in parallel batches of batchSize
            for (let i = 0; i < uris.length; i += batchSize) {
                const currentBatchURIs = uris.slice(i, i + batchSize);
                const promisesToProcess = currentBatchURIs.map(uri => processURI(uri));
                const results = await Promise.all(promisesToProcess);
                for (const res of results) {
                    result.push(res);
                }
            }
        } catch (error) {
            const errMsg = `Transcribe error: ${error?.response?.data || error?.message || error}`;
            logger.error(errMsg);
            return errMsg;
        } finally {
            try {
                for (const chunk of chunks) {
                    try {
                        await deleteTempPath(chunk);
                    } catch (error) {
                        // ignore error
                    }
                }
                await markCompletedForCleanUp(requestId);
            } catch (error) {
                logger.error(`An error occurred while deleting: ${error}`);
            }
        }

        if (['srt', 'vtt'].includes(responseFormat) || wordTimestamped) {
            // align subtitles across chunks for timestamped formats
            return alignSubtitles(result, responseFormat, offsets);
        }
        return result.join(' ');
    }
}

export default OpenAIWhisperPlugin;
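The `sendProgress` helper above keeps the progress indicator moving between real chunk completions: each 3-second heartbeat adds a logarithmically shrinking increment, capped at 0.99 of one chunk, so reported progress advances steadily but never claims a chunk is finished before it is. A standalone sketch of that arithmetic, with illustrative counts rather than values from a real request:

// Standalone sketch of the diminishing-increment progress heuristic
// used by sendProgress above; completedCount and totalCount are
// illustrative values, not taken from a real request.
let partialCount = 0;
let partialRatio = 0;
const completedCount = 2;  // chunks fully transcribed so far
const totalCount = 5;      // chunks expected overall

for (let tick = 0; tick < 10; tick++) {
    partialCount++;
    const increment = 0.02 / Math.log2(partialCount + 1); // shrinks each tick
    partialRatio = Math.min(partialRatio + increment, 0.99); // never a full chunk
    const progress = (completedCount + partialRatio) / totalCount;
    console.log(`tick ${tick + 1}: progress ~ ${progress.toFixed(4)}`);
}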