UNPKG

@stackbit/utils

Version:
629 lines (574 loc) 21.8 kB
import EventSource from 'eventsource';
import axios from 'axios';
import { URL } from 'url';
import { Client as NotionClient } from '@notionhq/client';
import { NotionToMarkdown } from 'notion-to-md';
import {
    CustomActionModel,
    CustomActionRunCommonOptions,
    CustomActionModelRunOptions,
    CustomActionInputField,
    CustomActionProgressFunction,
    ModelMap,
    Logger
} from '@stackbit/types';

export type GenerateContentFromPresetOptions = {
    /**
     * The label for the AI action.
     * @default Generate content with AI
     */
    label?: string;
    /**
     * The name of the field that contains the page's section.
     * For example, 'sections', 'blocks', 'components'.
     * Use `modelsConfig.mainListField` when setting root config actions.
     */
    mainListField?: string;
    /**
     * Custom AI prompt.
     * Use `modelsConfig.customPrompt` when setting root config actions.
     */
    customPrompt?: string;
    /**
     * Allow content editor to change the custom prompt.
     */
    allowOverrideCustomPrompt?: boolean;
    /**
     * Use `modelsConfig` when setting root config actions, and specify
     * `mainListField` and `customPrompt` per model `name`.
     */
    modelsConfig?: {
        name: string;
        mainListField?: string;
        customPrompt?: string;
    }[];
    /**
     * Temporary solution to connect the action with site related AI-knowledge.
     */
    siteId?: string;
};

/**
 * Builds the `_create_from_preset_ai` model action.
 *
 * When run, the action fetches source content (Google doc, Notion page or
 * inline content), sends it to the Spark content-gen workload together with
 * the selected preset and the model schemas, and creates a new document of
 * the action's model from the workload result.
 *
 * @param options - See {@link GenerateContentFromPresetOptions}.
 * @returns The custom action definition.
 */
function GenerateContentFromPreset({
    label,
    mainListField,
    customPrompt,
    allowOverrideCustomPrompt,
    modelsConfig,
    siteId
}: GenerateContentFromPresetOptions): CustomActionModel {
    return {
        type: 'model',
        name: '_create_from_preset_ai',
        label: label ?? 'Generate content with AI',
        models: modelsConfig ? modelsConfig.map((model) => model.name) : undefined,
        inputFields: [
            {
                type: 'slug',
                name: 'slug',
                label: 'Slug'
            },
            ...(siteId
                ? [
                      {
                          type: 'string',
                          name: '_spark_ai_site_id',
                          hidden: true,
                          default: siteId
                      } satisfies CustomActionInputField
                  ]
                : []),
            ...(allowOverrideCustomPrompt
                ? [
                      {
                          type: 'text',
                          name: 'customPrompt',
                          label: 'Custom Prompt',
                          default: customPrompt
                      } satisfies CustomActionInputField
                  ]
                : [])
        ],
        run: async (options: CustomActionRunCommonOptions & CustomActionModelRunOptions) => {
            const logger = options.getLogger();
            logger.info(`generate content from preset`);

            // Index every model of every schema by name for the workload.
            const schemas = options.getSchemas();
            const modelsByName: ModelMap = {};
            for (const schema of schemas) {
                for (const model of schema.models) {
                    modelsByName[model.name] = model;
                }
            }

            const sparkClient = new SparkClient({ logger });

            // Resolve the per-model overrides into locals. The outer
            // `mainListField` / `customPrompt` are shared by every invocation of
            // this action, so they must not be reassigned here: doing so would
            // leak one model's configuration into later runs for other models.
            const modelConf = (modelsConfig ?? []).find((model) => model.name === options.actionModel.name);
            const effectiveMainListField = modelConf?.mainListField ?? mainListField;
            const effectiveCustomPrompt = modelConf?.customPrompt ?? customPrompt;

            const progressCallback = createProgressCallback(options.progress);

            progressCallback({
                percent: 0,
                categoryMessage: 'Initializing',
                progressMessage: 'Initializing'
            });

            const contentMarkdown = await fetchContent(options, progressCallback, logger);
            if (!contentMarkdown || contentMarkdown.length === 0) {
                throw new Error('Content must not be empty');
            }

            const result = await sparkClient.performContentGen({
                contentMarkdown,
                knowledge: getKnowledgeForInputData(options.inputData),
                preset: options.inputData?.presetData,
                // A prompt entered by the editor wins over the configured one.
                customPrompt: options.inputData?.customPrompt ?? effectiveCustomPrompt,
                modelName: options.actionModel.name,
                sectionsField: effectiveMainListField,
                engineName: 'openai',
                modelsByName,
                progressCallback
            });

            if (options.inputData?.slug) {
                result.slug = options.inputData?.slug;
            }

            logger.info('create document');
            progressCallback({
                percent: 95,
                categoryMessage: 'Saving content',
                progressMessage: 'Saving content'
            });

            const { documentId } = await options.contentSourceActions.createDocumentFromObject({
                modelName: options.actionModel.name,
                object: result
            });

            logger.info(`created document with ID: '${documentId}'`);
            return {
                success: 'Successfully generated document',
                result: {
                    documentId
                }
            };
        }
    };
}

/**
 * Thin HTTP client for the Spark content-gen workload API.
 *
 * The base URL is read from `SPARK_URL` (with a built-in default) and an
 * optional API key from `SPARK_API_KEY`.
 */
class SparkClient {
    private readonly logger: Logger;

    constructor({ logger }: { logger: Logger }) {
        this.logger = logger;
    }

    /**
     * Runs a content-gen workload end to end: initializes it, uploads the
     * preset, model schemas and source content to the URLs returned by the
     * init call, starts the workload, waits for completion over SSE and
     * finally fetches the result.
     *
     * @returns The `data` payload of the workload-result response.
     * @throws Error when any of the HTTP steps fails, when the init response
     *   has no `startWorkload` URL, or when waiting for completion fails.
     */
    async performContentGen({
        contentMarkdown,
        knowledge,
        preset,
        modelName,
        sectionsField,
        customPrompt,
        engineName,
        modelsByName,
        progressCallback
    }: {
        contentMarkdown: string;
        knowledge?: SparkKnowledge;
        preset: any;
        modelName: string;
        sectionsField?: string;
        customPrompt?: string;
        engineName?: string;
        modelsByName: Record<string, any>;
        progressCallback: ProgressCallback;
    }) {
        // The init call only receives hashes of the payloads; the payloads
        // themselves are uploaded afterwards.
        const presetData = JSON.stringify(preset);
        const modelsByNameData = JSON.stringify(modelsByName);

        this.logger.debug('initialize content-gen workload');

        const baseUrl = process.env.SPARK_URL ?? 'https://api-create.services.netlify.com/spark';
        const initializedWorkloadResult = await axios({
            url: `${baseUrl}/api/v1/workload/content-gen`,
            method: 'post',
            headers: {
                'Content-Type': 'application/json',
                ...(process.env.SPARK_API_KEY ? { 'x-spark-api-key': process.env.SPARK_API_KEY } : null)
            },
            data: {
                inputs: {
                    model: modelName,
                    engine: engineName ?? 'openai',
                    sectionsField: sectionsField,
                    customPrompt: customPrompt
                },
                uploads: {
                    createData: {
                        preset: hashCode(presetData),
                        modelsByName: hashCode(modelsByNameData)
                    },
                    sourceData: {
                        sourceContent: hashCode(contentMarkdown)
                    }
                },
                knowledge
            }
        }).catch((error) => {
            throw new Error(`Failed to initialize AI workload: ${error.message}`);
        });

        // all of our signed urls are set on this workload result object.
        const workloadData = initializedWorkloadResult.data;

        this.logger.debug('initialized workload');

        if (!workloadData.startWorkload) {
            throw new Error('Workload does not contain startWorkload URL');
        }

        // upload the resources we specified that we would upload on the original init call
        this.logger.debug('uploading workload createData');
        progressCallback({
            percent: 5,
            categoryMessage: 'Initializing',
            progressMessage: 'Uploading content'
        });

        await Promise.all([
            axios({
                url: workloadData.requiredUploads.createData.preset,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: preset
            }),
            axios({
                url: workloadData.requiredUploads.createData.modelsByName,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: modelsByName
            }),
            axios({
                url: workloadData.requiredUploads.sourceData.sourceContent,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: contentMarkdown
            })
        ]).catch((error) => {
            throw new Error(`Failed to upload data to AI workload: ${error.message}`);
        });

        // now that all files are uploaded, kick off the workload
        await axios({
            url: workloadData.startWorkload,
            method: 'post'
        }).catch((error) => {
            throw new Error(`Failed to start AI workload: ${error.message}`);
        });

        this.logger.debug('started workload');

        // wait until the workload is done
        await this.waitUntilDone({
            sseUrl: workloadData.progress.sse,
            progressCallback: (sparkEvent: SparkEvent) => {
                progressCallback(null, sparkEvent);
            }
        });

        progressCallback({
            percent: 90,
            categoryMessage: 'Transforming content',
            progressMessage: 'Content transformation completed'
        });

        // fetch the result
        this.logger.debug('workload finished, fetching workload result');
        const workloadResult = await axios({
            url: workloadData.workloadResult,
            method: 'get'
        }).catch((error) => {
            throw new Error(`Failed to get AI workload result: ${error.message}`);
        });

        this.logger.debug('got workload result', {
            status: workloadResult.status,
            statusText: workloadResult.statusText
        });

        return workloadResult.data;
    }

    /**
     * Subscribes to the workload's SSE stream and settles once an event with
     * `latest.percentage >= 100` arrives: rejects when that event carries
     * `systemFailure`, resolves otherwise. Earlier events are forwarded to
     * `progressCallback`.
     *
     * A stream error closes the connection and re-subscribes after one second;
     * after three retries the promise is rejected. An event whose data cannot
     * be parsed as JSON rejects immediately.
     *
     * @param retryCount - Number of retries already performed (internal).
     */
    waitUntilDone({
        sseUrl,
        progressCallback,
        retryCount = 0
    }: {
        sseUrl: string;
        progressCallback?: (sparkEvent: SparkEvent) => void;
        retryCount?: number;
    }): Promise<void> {
        this.logger.debug('subscribe to server-side-events', { retryCount });
        const sse = new EventSource(sseUrl);
        return new Promise((resolve, reject) => {
            sse.onerror = (event) => {
                sse.close();
                this.logger.error('server-side-event error', { event });
                if (retryCount < 3) {
                    setTimeout(() => {
                        // Resolving with the retry's promise chains its outcome onto this one.
                        resolve(
                            this.waitUntilDone({
                                sseUrl,
                                progressCallback,
                                retryCount: retryCount + 1
                            })
                        );
                    }, 1000);
                } else {
                    this.logger.error('got 3 server-side-event errors, aborting');
                    reject(new Error('AI Workload failed: could not establish SSE channel'));
                }
            };
            sse.addEventListener('message', (event) => {
                this.logger.debug('eventsource event:', { event });
                try {
                    const sparkEvent = JSON.parse(event.data) as SparkEvent;
                    if (!sparkEvent.latest) {
                        return;
                    }
                    const percentage = sparkEvent.latest.percentage;
                    if (typeof percentage === 'number' && percentage >= 100) {
                        sse.close();
                        if (sparkEvent.latest.systemFailure) {
                            reject(new Error(`AI Workload failed: ${sparkEvent.latest.errorMessage}`));
                        } else {
                            resolve();
                        }
                    } else {
                        progressCallback?.(sparkEvent);
                    }
                } catch {
                    sse.close();
                    reject(new Error('AI Workload failed: could not parse Spark event'));
                }
            });
        });
    }
}

/** Payload of a single SSE message emitted by the workload progress stream. */
type SparkEvent = {
    latest?: SparkEventProgressMessage;
    allProgress?: SparkEventProgressMessage[];
};

/** One progress step, either reported remotely or synthesized locally. */
type SparkEventProgressMessage = {
    percentage?: number;
    categoryMessage?: string;
    progressMessage?: string;
    systemFailure?: boolean;
    errorMessage?: string;
};

/**
 * Progress reporter accepting either a locally generated step (`localEvent`)
 * or, with `localEvent` set to `null`, an event received from the workload.
 */
type ProgressCallback = (
    localEvent: { percent: number; categoryMessage: string; progressMessage: string } | null,
    remoteSparkEvent?: SparkEvent
) => void;

/**
 * Wraps the action's progress function so that local steps and remote
 * workload events are merged into one timeline.
 *
 * Remote percentages are rescaled with `10 + round(p * 0.8)`, i.e. a remote
 * 0..100 maps onto 10..90. Local steps reported before the first remote event
 * are listed ahead of the remote steps, later local steps after them. The
 * merged event is forwarded as `{ percent, message }` where `message` is the
 * JSON-serialized event.
 *
 * @param progressCallback - The progress function provided to the action run.
 */
function createProgressCallback(progressCallback: CustomActionProgressFunction): ProgressCallback {
    const localProgressBefore: SparkEventProgressMessage[] = [];
    const localProgressAfter: SparkEventProgressMessage[] = [];
    let remoteProgress: SparkEventProgressMessage[] = [];
    let gotRemoteEvents = false;

    const adjustPercentage = (percentage?: number) => {
        return 10 + Math.round((percentage ?? 0) * 0.8);
    };

    return (localEvent: { percent: number; categoryMessage: string; progressMessage: string } | null, remoteSparkEvent?: SparkEvent) => {
        let sparkEvent: SparkEvent | undefined;
        if (remoteSparkEvent) {
            gotRemoteEvents = true;
            // Keep the previously seen remote steps when this event carries none.
            remoteProgress =
                remoteSparkEvent.allProgress?.map((step) => {
                    return {
                        ...step,
                        percentage: adjustPercentage(step.percentage)
                    };
                }) ?? remoteProgress;
            if (remoteSparkEvent.latest) {
                sparkEvent = {
                    ...remoteSparkEvent,
                    latest: {
                        ...remoteSparkEvent.latest,
                        percentage: adjustPercentage(remoteSparkEvent.latest.percentage)
                    },
                    allProgress: localProgressBefore.concat(remoteProgress ?? [])
                };
            }
        } else if (localEvent) {
            const latest = createSparkEventProgressMessage(localEvent);
            if (!gotRemoteEvents) {
                localProgressBefore.push(latest);
                sparkEvent = {
                    latest,
                    allProgress: localProgressBefore
                };
            } else {
                localProgressAfter.push(latest);
                sparkEvent = {
                    latest,
                    allProgress: localProgressBefore.concat(remoteProgress, localProgressAfter)
                };
            }
        }
        if (sparkEvent?.latest) {
            progressCallback?.({
                percent: sparkEvent.latest?.percentage,
                message: JSON.stringify(sparkEvent)
            });
        }
    };
}

/**
 * Converts a local progress step (which uses `percent`) into the
 * {@link SparkEventProgressMessage} shape (which uses `percentage`).
 */
function createSparkEventProgressMessage(options: {
    percent: number;
    categoryMessage: string;
    progressMessage: string;
    systemFailure?: boolean;
    errorMessage?: string;
}): SparkEventProgressMessage {
    return {
        percentage: options.percent,
        categoryMessage: options.categoryMessage,
        progressMessage: options.progressMessage,
        systemFailure: options.systemFailure,
        errorMessage: options.errorMessage
    };
}

/**
 * Computes a signed 32-bit hash of `str` by folding each UTF-16 code unit
 * with `hash * 31 + code`. Returns 0 for the empty string.
 */
function hashCode(str: string) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
        const code = str.charCodeAt(i);
        // (hash << 5) - hash === hash * 31
        hash = (hash << 5) - hash + code;
        hash = hash & hash; // Convert to 32bit integer
    }
    return hash;
}

/** Knowledge selection sent with the workload init request. */
type SparkKnowledge = {
    scopes: ScopeSelections;
    selected: SparkKnowledgeSelection[];
};

type ScopeSelections = {
    shared?: 'netlify-known' | 'public';
    orgId?: string;
    accountId?: string;
    siteId?: string;
};

type SparkKnowledgeSelection = {
    type: string;
    scopeType?: keyof ScopeSelections;
    id?: string;
};

/**
 * Derives the workload knowledge selection from the action's input data.
 *
 * `_spark_ai_knowledge` is returned as-is when present. Otherwise a selection
 * is built from `_spark_ai_voiceTone` and `_spark_ai_targetAudience` under the
 * shared `netlify-known` scope.
 *
 * @returns The knowledge selection, or `undefined` when there is no input
 *   data or none of the knowledge fields is set.
 */
function getKnowledgeForInputData(inputData?: Record<string, any>): SparkKnowledge | undefined {
    if (!inputData) {
        return undefined;
    }
    if (inputData._spark_ai_knowledge) {
        return inputData._spark_ai_knowledge;
    }
    const selected: SparkKnowledgeSelection[] = [];
    if (inputData._spark_ai_voiceTone) {
        selected.push({
            type: 'voice-and-tone',
            id: inputData._spark_ai_voiceTone
        });
    }
    if (inputData._spark_ai_targetAudience) {
        selected.push({
            type: 'target-audience',
            id: inputData._spark_ai_targetAudience
        });
    }
    if (selected.length === 0) {
        return undefined;
    }
    return {
        scopes: {
            shared: 'netlify-known'
        },
        selected
    };
}

/**
 * Resolves the source content for the workload. Tries, in order: a Google
 * doc, a Notion page, and finally the raw `_spark_ai_content` input value.
 *
 * @returns The content, or `undefined` when no source yields any.
 * @throws Error when a Google/Notion URL is given but the current user has no
 *   matching connection, or when fetching the document fails.
 */
async function fetchContent(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string | undefined> {
    const googleDoc = await fetchGoogleDoc(options, progressCallback, logger);
    if (googleDoc) {
        logger.debug('Got Google doc');
        return googleDoc;
    }
    const notionPage = await fetchNotionPage(options, progressCallback, logger);
    if (notionPage) {
        logger.debug('Got Notion page');
        return notionPage;
    }
    return options.inputData?._spark_ai_content;
}

/**
 * Exports a Google doc as markdown through the Drive v3 export endpoint.
 *
 * The document URL is read from `_spark_ai_googleDocUrl`, falling back to
 * `_spark_ai_content`, and must contain `docs.google.com/document/d/{id}`.
 *
 * @returns The response body, or `undefined` when the input is not a string
 *   or contains no Google doc URL.
 * @throws Error when the current user has no `google` connection with an
 *   access token, or when the export request fails.
 */
async function fetchGoogleDoc(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string | undefined> {
    const googleDocUrl = options.inputData?._spark_ai_googleDocUrl ?? options.inputData?._spark_ai_content;
    // Input data is untyped; only strings can carry a URL.
    if (typeof googleDocUrl !== 'string') {
        return undefined;
    }
    const match = googleDocUrl.match(/docs\.google\.com\/document\/d\/([^/?#]+)/);
    if (!match) {
        return undefined;
    }
    const googleDocId = match[1];
    if (!googleDocId) {
        return undefined;
    }

    logger.debug('Found Google doc URL');

    let googleConnection;
    if (options.currentUser && 'connections' in options.currentUser && Array.isArray(options.currentUser.connections)) {
        googleConnection = options.currentUser.connections.find((connection) => connection.type === 'google');
    }
    if (!googleConnection || !('accessToken' in googleConnection)) {
        logger.debug('User does not have Google connection');
        throw new Error('Please connect your Google account');
    }

    progressCallback({
        percent: 2,
        categoryMessage: 'Initializing',
        progressMessage: `Fetching from ${googleDocUrl}`
    });

    logger.debug('Fetching Google doc');

    // const url = `https://docs.google.com/feeds/download/documents/export/Export?id=${documentId}&exportFormat=markdown`;
    const url = `https://www.googleapis.com/drive/v3/files/${googleDocId}/export?mimeType=text/markdown`;
    const googleDocResult = await axios({
        url: url,
        method: 'get',
        headers: {
            Authorization: `Bearer ${googleConnection.accessToken}`
        }
    }).catch((error) => {
        throw new Error(`Failed to fetch Google doc: ${error.message}`);
    });

    return googleDocResult.data;
}

/**
 * Fetches a Notion page and converts it to markdown.
 *
 * The page URL is read from `_spark_ai_notionUrl`, falling back to
 * `_spark_ai_content`, and must be a valid URL on `notion.so` or
 * `www.notion.so`.
 *
 * @returns The `parent` markdown string of the converted page, or `undefined`
 *   when the input is not a Notion page URL.
 * @throws Error when the current user has no `notion` connection with an
 *   access token.
 */
async function fetchNotionPage(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string | undefined> {
    // Example for notion page URLs
    // https://www.notion.so/{COMPANY_SLUG}/{PAGE_SLUG}-{PAGE_ID}
    // https://www.notion.so/{PAGE_ID}
    // https://notion.so/{PAGE_SLUG}-{PAGE_ID}
    const notionUrl = options.inputData?._spark_ai_notionUrl ?? options.inputData?._spark_ai_content;
    // Input data is untyped; only strings can carry a URL.
    if (typeof notionUrl !== 'string' || !/notion\.so\//.test(notionUrl)) {
        return undefined;
    }

    let urlObject: URL;
    try {
        urlObject = new URL(notionUrl);
    } catch {
        // Not a parsable URL (e.g. free text that merely mentions notion.so/).
        return undefined;
    }

    if (!['www.notion.so', 'notion.so'].includes(urlObject.hostname)) {
        return undefined;
    }

    // The notion page ID is encoded into the last section of the URL path after the last hyphen
    // https://www.notion.so/Hello-World-12a4dc5adgh230e5ad5ad56456b457d9
    //                                   ↳       Notion Page ID         ↲
    const pathParts = urlObject.pathname.replace(/^\//, '').split('/');
    const notionPageSlugAndId = pathParts[pathParts.length - 1];
    if (!notionPageSlugAndId) {
        return undefined;
    }
    const pageSlugAndIdParts = notionPageSlugAndId.split('-');
    const notionPageId = pageSlugAndIdParts[pageSlugAndIdParts.length - 1];
    if (!notionPageId) {
        return undefined;
    }

    logger.debug(`Found Notion page URL with ID: ${notionPageId}`);

    let notionConnection;
    if (options.currentUser && 'connections' in options.currentUser && Array.isArray(options.currentUser.connections)) {
        notionConnection = options.currentUser.connections.find((connection) => connection.type === 'notion');
    }
    if (!notionConnection || !('accessToken' in notionConnection)) {
        logger.debug('User does not have Notion connection');
        throw new Error('Please connect your Notion account');
    }

    progressCallback({
        percent: 2,
        categoryMessage: 'Initializing',
        progressMessage: `Fetching from ${notionUrl}`
    });

    logger.debug('Fetching Notion page');

    const notionClient = new NotionClient({
        auth: notionConnection.accessToken
    });

    const n2m = new NotionToMarkdown({ notionClient });
    const mdBlocks = await n2m.pageToMarkdown(notionPageId);
    const mdString = n2m.toMarkdownString(mdBlocks);
    return mdString.parent;
}

export const Actions = {
    GenerateContentFromPreset
};