@stackbit/utils
Version:
Stackbit utilities
629 lines (574 loc) • 21.8 kB
text/typescript
import EventSource from 'eventsource';
import axios from 'axios';
import { URL } from 'url';
import { Client as NotionClient } from '@notionhq/client';
import { NotionToMarkdown } from 'notion-to-md';
import {
CustomActionModel,
CustomActionRunCommonOptions,
CustomActionModelRunOptions,
CustomActionInputField,
CustomActionProgressFunction,
ModelMap,
Logger
} from '@stackbit/types';
/**
* Options accepted by `GenerateContentFromPreset` when building the
* "generate content with AI" custom action.
*/
export type GenerateContentFromPresetOptions = {
/**
* The label for the AI action.
* @default Generate content with AI
*/
label?: string;
/**
* The name of the field that contains the page's sections.
* For example, 'sections', 'blocks', 'components'.
* Use `modelsConfig.mainListField` when setting root config actions.
*/
mainListField?: string;
/**
* Custom AI prompt.
* Use `modelsConfig.customPrompt` when setting root config actions.
*/
customPrompt?: string;
/**
* Allow content editor to change the custom prompt.
* When truthy, the action gets a 'Custom Prompt' text input whose default
* value is `customPrompt`.
*/
allowOverrideCustomPrompt?: boolean;
/**
* Use `modelsConfig` when setting root config actions, and specify
* `mainListField` and `customPrompt` per model `name`.
* When provided, the model names listed here become the action's `models`.
*/
modelsConfig?: {
name: string;
mainListField?: string;
customPrompt?: string;
}[];
/**
* Temporary solution to connect the action with site related AI-knowledge.
* When set, it becomes the default of a hidden `_spark_ai_site_id` input field.
*/
siteId?: string;
};
/**
 * Creates the model-level custom action that generates a new document with AI.
 *
 * The returned action:
 *  1. collects a slug (plus, optionally, a hidden site ID and an editable
 *     custom prompt) from the editor,
 *  2. loads the source content through `fetchContent`,
 *  3. sends it to Spark through `SparkClient.performContentGen`,
 *  4. stores the generated object with
 *     `contentSourceActions.createDocumentFromObject`.
 *
 * @param options - See `GenerateContentFromPresetOptions`.
 * @returns The custom action definition named `_create_from_preset_ai`.
 */
function GenerateContentFromPreset({
    label,
    mainListField,
    customPrompt,
    allowOverrideCustomPrompt,
    modelsConfig,
    siteId
}: GenerateContentFromPresetOptions): CustomActionModel {
    return {
        type: 'model',
        name: '_create_from_preset_ai',
        label: label ?? 'Generate content with AI',
        // With per-model configuration the action is restricted to the listed models.
        models: modelsConfig ? modelsConfig.map((model) => model.name) : undefined,
        inputFields: [
            {
                type: 'slug',
                name: 'slug',
                label: 'Slug'
            },
            // Hidden field that carries the site ID along with the submitted input data.
            ...(siteId
                ? [
                      {
                          type: 'string',
                          name: '_spark_ai_site_id',
                          hidden: true,
                          default: siteId
                      } satisfies CustomActionInputField
                  ]
                : []),
            // Optional editor-facing prompt override, prefilled with the configured prompt.
            ...(allowOverrideCustomPrompt
                ? [
                      {
                          type: 'text',
                          name: 'customPrompt',
                          label: 'Custom Prompt',
                          default: customPrompt
                      } satisfies CustomActionInputField
                  ]
                : [])
        ],
        /**
         * Runs the content generation for `options.actionModel`.
         *
         * @throws Error 'Content must not be empty' when no source content could be resolved.
         */
        run: async (options: CustomActionRunCommonOptions & CustomActionModelRunOptions) => {
            const logger = options.getLogger();
            logger.info(`generate content from preset`);

            // Flatten the models of every schema into a single name -> model lookup.
            const schemas = options.getSchemas();
            const modelsByName: ModelMap = {};
            for (const schema of schemas) {
                for (const model of schema.models) {
                    modelsByName[model.name] = model;
                }
            }

            const sparkClient = new SparkClient({ logger });

            // Resolve the per-model overrides into locals. The factory parameters must
            // not be reassigned: they are shared by every invocation of `run`, so
            // mutating them would leak one model's settings into later runs for
            // other models.
            const modelConf = (modelsConfig ?? []).find((model) => model.name === options.actionModel.name);
            const effectiveMainListField = modelConf?.mainListField ?? mainListField;
            const effectiveCustomPrompt = modelConf?.customPrompt ?? customPrompt;

            const progressCallback = createProgressCallback(options.progress);
            progressCallback({
                percent: 0,
                categoryMessage: 'Initializing',
                progressMessage: 'Initializing'
            });

            const contentMarkdown = await fetchContent(options, progressCallback, logger);
            if (!contentMarkdown || contentMarkdown.length === 0) {
                throw new Error('Content must not be empty');
            }

            const result = await sparkClient.performContentGen({
                contentMarkdown,
                knowledge: getKnowledgeForInputData(options.inputData),
                preset: options.inputData?.presetData,
                // A prompt submitted by the editor wins over the configured one.
                customPrompt: options.inputData?.customPrompt ?? effectiveCustomPrompt,
                modelName: options.actionModel.name,
                sectionsField: effectiveMainListField,
                engineName: 'openai',
                modelsByName,
                progressCallback
            });

            // The slug entered by the editor overrides whatever the generated object carries.
            if (options.inputData?.slug) {
                result.slug = options.inputData?.slug;
            }

            logger.info('create document');
            progressCallback({
                percent: 95,
                categoryMessage: 'Saving content',
                progressMessage: 'Saving content'
            });
            const { documentId } = await options.contentSourceActions.createDocumentFromObject({
                modelName: options.actionModel.name,
                object: result
            });
            logger.info(`created document with ID: '${documentId}'`);
            return {
                success: 'Successfully generated document',
                result: { documentId }
            };
        }
    };
}
/**
 * Client for the Spark content-generation workload API.
 *
 * The base URL comes from `process.env.SPARK_URL` (with a built-in default) and,
 * when `process.env.SPARK_API_KEY` is set, it is sent as the `x-spark-api-key`
 * header of the initialization request.
 */
class SparkClient {
    private readonly logger: Logger;

    /**
     * @param options.logger - Logger used for debug/error output of every request step.
     */
    constructor({ logger }: { logger: Logger }) {
        this.logger = logger;
    }

    /**
     * Runs one content-generation workload end to end: initializes the workload,
     * uploads the preset / models / source content to the URLs returned by the
     * initialization call, starts the workload, waits for completion over SSE and
     * finally downloads the result.
     *
     * @param contentMarkdown - Source content to transform.
     * @param knowledge - Optional knowledge selection forwarded to the workload.
     * @param preset - Preset data uploaded as-is.
     * @param modelName - Name of the model to generate, sent as `inputs.model`.
     * @param sectionsField - Sent as `inputs.sectionsField`.
     * @param customPrompt - Sent as `inputs.customPrompt`.
     * @param engineName - Sent as `inputs.engine`; defaults to 'openai'.
     * @param modelsByName - Model map uploaded as-is.
     * @param progressCallback - Receives local progress steps and forwarded remote events.
     * @returns The `data` of the workload result response.
     * @throws Error when any request fails, when the initialization response lacks
     *         a required URL, or when the workload reports a failure.
     */
    async performContentGen({
        contentMarkdown,
        knowledge,
        preset,
        modelName,
        sectionsField,
        customPrompt,
        engineName,
        modelsByName,
        progressCallback
    }: {
        contentMarkdown: string;
        knowledge?: SparkKnowledge;
        preset: any;
        modelName: string;
        sectionsField?: string;
        customPrompt?: string;
        engineName?: string;
        modelsByName: Record<string, any>;
        progressCallback: ProgressCallback;
    }) {
        // NOTE(review): JSON.stringify yields undefined for an undefined preset, which
        // makes hashCode throw on `.length` — confirm callers always supply presetData.
        const presetData = JSON.stringify(preset);
        const modelsByNameData = JSON.stringify(modelsByName);

        this.logger.debug('initialize content-gen workload');
        const baseUrl = process.env.SPARK_URL ?? 'https://api-create.services.netlify.com/spark';
        const initializedWorkloadResult = await axios({
            url: `${baseUrl}/api/v1/workload/content-gen`,
            method: 'post',
            headers: {
                'Content-Type': 'application/json',
                ...(process.env.SPARK_API_KEY
                    ? {
                          'x-spark-api-key': process.env.SPARK_API_KEY
                      }
                    : null)
            },
            data: {
                inputs: {
                    model: modelName,
                    engine: engineName ?? 'openai',
                    sectionsField: sectionsField,
                    customPrompt: customPrompt
                },
                // Only hashes of the payloads are announced here; the payloads
                // themselves are uploaded separately below.
                uploads: {
                    createData: {
                        preset: hashCode(presetData),
                        modelsByName: hashCode(modelsByNameData)
                    },
                    sourceData: {
                        sourceContent: hashCode(contentMarkdown)
                    }
                },
                knowledge
            }
        }).catch((error) => {
            throw new Error(`Failed to initialize AI workload: ${error.message}`);
        });

        // all of our signed urls are set on this workload result object.
        const workloadData = initializedWorkloadResult.data;
        this.logger.debug('initialized workload');
        if (!workloadData.startWorkload) {
            throw new Error('Workload does not contain startWorkload URL');
        }

        // Validate every other URL used below up front, so a malformed response
        // produces a descriptive error instead of a TypeError halfway through
        // (or after the workload has already run).
        const requiredUrls: Record<string, unknown> = {
            'requiredUploads.createData.preset': workloadData.requiredUploads?.createData?.preset,
            'requiredUploads.createData.modelsByName': workloadData.requiredUploads?.createData?.modelsByName,
            'requiredUploads.sourceData.sourceContent': workloadData.requiredUploads?.sourceData?.sourceContent,
            'progress.sse': workloadData.progress?.sse,
            workloadResult: workloadData.workloadResult
        };
        const missingUrls = Object.keys(requiredUrls).filter((key) => !requiredUrls[key]);
        if (missingUrls.length > 0) {
            throw new Error(`Workload does not contain required URLs: ${missingUrls.join(', ')}`);
        }

        // upload the resources we specified that we would upload on the original init call
        this.logger.debug('uploading workload createData');
        progressCallback({
            percent: 5,
            categoryMessage: 'Initializing',
            progressMessage: 'Uploading content'
        });
        await Promise.all([
            axios({
                url: workloadData.requiredUploads.createData.preset,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: preset
            }),
            axios({
                url: workloadData.requiredUploads.createData.modelsByName,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: modelsByName
            }),
            axios({
                url: workloadData.requiredUploads.sourceData.sourceContent,
                method: 'post',
                headers: { 'Content-Type': 'application/json' },
                data: contentMarkdown
            })
        ]).catch((error) => {
            throw new Error(`Failed to upload data to AI workload: ${error.message}`);
        });

        // now that all files are uploaded, kick off the workload
        await axios({
            url: workloadData.startWorkload,
            method: 'post'
        }).catch((error) => {
            throw new Error(`Failed to start AI workload: ${error.message}`);
        });
        this.logger.debug('started workload');

        // wait until the workload is done
        await this.waitUntilDone({
            sseUrl: workloadData.progress.sse,
            progressCallback: (sparkEvent: SparkEvent) => {
                // Remote events go in the second argument; the first one is for local steps.
                progressCallback(null, sparkEvent);
            }
        });
        progressCallback({
            percent: 90,
            categoryMessage: 'Transforming content',
            progressMessage: 'Content transformation completed'
        });

        // fetch the result
        this.logger.debug('workload finished, fetching workload result');
        const workloadResult = await axios({
            url: workloadData.workloadResult,
            method: 'get'
        }).catch((error) => {
            throw new Error(`Failed to get AI workload result: ${error.message}`);
        });
        this.logger.debug('got workload result', {
            status: workloadResult.status,
            statusText: workloadResult.statusText
        });
        return workloadResult.data;
    }

    /**
     * Subscribes to the workload's server-sent events and settles once an event
     * with `latest.percentage >= 100` arrives.
     *
     * On a stream error the connection is closed and re-opened after one second,
     * up to three retries; afterwards the promise rejects.
     *
     * @param sseUrl - URL of the SSE stream.
     * @param progressCallback - Called for every event that has `latest` and is not final.
     * @param retryCount - Number of reconnects already attempted (internal).
     * @returns Resolves when the workload finished successfully.
     * @throws Error when the final event has `systemFailure` set, when an event
     *         cannot be parsed, when the progress callback throws, or when the
     *         SSE channel keeps failing.
     */
    waitUntilDone({
        sseUrl,
        progressCallback,
        retryCount = 0
    }: {
        sseUrl: string;
        progressCallback?: (sparkEvent: SparkEvent) => void;
        retryCount?: number;
    }): Promise<void> {
        this.logger.debug('subscribe to server-side-events', { retryCount });
        const sse = new EventSource(sseUrl);
        return new Promise((resolve, reject) => {
            sse.onerror = (event) => {
                // Close explicitly so reconnects are driven by the bounded retry below.
                sse.close();
                this.logger.error('server-side-event error', { event });
                if (retryCount < 3) {
                    setTimeout(() => {
                        // Resolving with the retry's promise chains its outcome to this one.
                        resolve(
                            this.waitUntilDone({
                                sseUrl,
                                progressCallback,
                                retryCount: retryCount + 1
                            })
                        );
                    }, 1000);
                } else {
                    this.logger.error('got 3 server-side-event errors, aborting');
                    reject(new Error('AI Workload failed: could not establish SSE channel'));
                }
            };
            sse.addEventListener('message', (event) => {
                this.logger.debug('eventsource event:', { event });

                let sparkEvent: SparkEvent | undefined;
                let latest: SparkEventProgressMessage | undefined;
                try {
                    sparkEvent = JSON.parse(event.data) as SparkEvent;
                    // Read inside the try so a `null` payload still counts as unparsable.
                    latest = sparkEvent.latest;
                } catch (error) {
                    sse.close();
                    reject(new Error('AI Workload failed: could not parse Spark event'));
                    return;
                }
                if (!sparkEvent || !latest) {
                    return;
                }

                const percentage = latest.percentage;
                if (typeof percentage === 'number' && percentage >= 100) {
                    sse.close();
                    if (latest.systemFailure) {
                        reject(new Error(`AI Workload failed: ${latest.errorMessage ?? 'unknown error'}`));
                    } else {
                        resolve();
                    }
                    return;
                }

                // Kept separate from the parsing above so a failing callback surfaces
                // its own error instead of being reported as a parse failure.
                try {
                    progressCallback?.(sparkEvent);
                } catch (error) {
                    sse.close();
                    reject(error instanceof Error ? error : new Error(String(error)));
                }
            });
        });
    }
}
/**
* Progress event as parsed from the `data` of an SSE message in
* `SparkClient.waitUntilDone`, and as serialized into the `message` passed to
* the custom-action progress function by `createProgressCallback`.
*/
type SparkEvent = {
// Events without `latest` are ignored by `waitUntilDone`.
// presumably the most recent progress step — confirm against the Spark API
latest?: SparkEventProgressMessage;
// presumably every progress step reported so far — confirm against the Spark API
allProgress?: SparkEventProgressMessage[];
};
/**
* A single progress step.
*/
type SparkEventProgressMessage = {
// `waitUntilDone` treats a numeric value >= 100 as the end of the workload.
percentage?: number;
categoryMessage?: string;
progressMessage?: string;
// When set on the final step, `waitUntilDone` rejects using `errorMessage`.
systemFailure?: boolean;
errorMessage?: string;
};
/**
* Progress reporter produced by `createProgressCallback`.
* Pass a `localEvent` for a locally generated step, or `null` plus a
* `remoteSparkEvent` to forward an event received from Spark.
*/
type ProgressCallback = (localEvent: { percent: number; categoryMessage: string; progressMessage: string } | null, remoteSparkEvent?: SparkEvent) => void;
/**
 * Wraps the custom-action progress function into a `ProgressCallback` that
 * merges locally generated steps with events forwarded from Spark.
 *
 * Remote percentages are rescaled to `10 + round(p * 0.8)` so that remote
 * progress of 0..100 occupies 10..90 of the overall bar. Every emitted update
 * carries the serialized event (latest step plus accumulated history) as its
 * `message`.
 *
 * @param progressCallback - Progress function supplied with the action's run options.
 * @returns Callback accepting either a local step or `null` plus a remote event.
 */
function createProgressCallback(progressCallback: CustomActionProgressFunction): ProgressCallback {
    // Local steps are kept in two buckets so the history stays chronological:
    // those emitted before the first remote event and those emitted after it.
    const stepsBeforeRemote: SparkEventProgressMessage[] = [];
    const stepsAfterRemote: SparkEventProgressMessage[] = [];
    let remoteSteps: SparkEventProgressMessage[] = [];
    let remoteStarted = false;

    const rescale = (percentage?: number) => 10 + Math.round((percentage ?? 0) * 0.8);

    const emit = (event: SparkEvent) => {
        progressCallback?.({
            percent: event.latest?.percentage,
            message: JSON.stringify(event)
        });
    };

    return (localEvent: { percent: number; categoryMessage: string; progressMessage: string } | null, remoteSparkEvent?: SparkEvent) => {
        if (remoteSparkEvent) {
            remoteStarted = true;

            // Replace the remote history only when the event actually carries one.
            const reportedSteps = remoteSparkEvent.allProgress;
            if (reportedSteps != null) {
                remoteSteps = reportedSteps.map((step) => ({
                    ...step,
                    percentage: rescale(step.percentage)
                }));
            }

            if (!remoteSparkEvent.latest) {
                return;
            }
            emit({
                ...remoteSparkEvent,
                latest: {
                    ...remoteSparkEvent.latest,
                    percentage: rescale(remoteSparkEvent.latest.percentage)
                },
                allProgress: stepsBeforeRemote.concat(remoteSteps)
            });
            return;
        }

        if (!localEvent) {
            return;
        }

        const latest = createSparkEventProgressMessage(localEvent);
        if (remoteStarted) {
            stepsAfterRemote.push(latest);
            emit({
                latest,
                allProgress: stepsBeforeRemote.concat(remoteSteps, stepsAfterRemote)
            });
        } else {
            stepsBeforeRemote.push(latest);
            emit({
                latest,
                allProgress: stepsBeforeRemote
            });
        }
    };
}
/**
 * Converts a local progress description into the step shape used by Spark
 * events; `percent` is exposed as `percentage`, all other fields keep their name.
 *
 * @param options - Local step; `systemFailure` and `errorMessage` are optional.
 * @returns The equivalent progress step.
 */
function createSparkEventProgressMessage(options: {
    percent: number;
    categoryMessage: string;
    progressMessage: string;
    systemFailure?: boolean;
    errorMessage?: string;
}): SparkEventProgressMessage {
    const { percent, categoryMessage, progressMessage, systemFailure, errorMessage } = options;
    return { percentage: percent, categoryMessage, progressMessage, systemFailure, errorMessage };
}
/**
 * Computes a signed 32-bit hash of a string by folding its UTF-16 code units
 * with `hash = hash * 31 + codeUnit`, wrapping to 32 bits after every step.
 *
 * @param str - String to hash.
 * @returns Signed 32-bit integer; 0 for the empty string.
 */
function hashCode(str: string) {
    let acc = 0;
    for (let index = 0; index < str.length; index += 1) {
        // Math.imul performs the 32-bit wrapping multiplication; `| 0` wraps the sum.
        acc = (Math.imul(acc, 31) + str.charCodeAt(index)) | 0;
    }
    return acc;
}
/**
* Knowledge selection sent as `knowledge` in the workload initialization
* request; built from the action input by `getKnowledgeForInputData`.
*/
type SparkKnowledge = {
scopes: ScopeSelections;
selected: SparkKnowledgeSelection[];
};
/**
* Scopes a knowledge selection applies to.
* NOTE(review): exact scope semantics are defined by the Spark service — confirm there.
*/
type ScopeSelections = {
// `getKnowledgeForInputData` uses 'netlify-known' when it builds the selection itself.
shared?: 'netlify-known' | 'public';
orgId?: string;
accountId?: string;
siteId?: string;
};
/**
* One selected knowledge entry.
*/
type SparkKnowledgeSelection = {
// Values produced in this file: 'voice-and-tone' and 'target-audience'.
type: string;
scopeType?: keyof ScopeSelections;
id?: string;
};
/**
 * Derives the knowledge selection for a workload from the action's input data.
 *
 * A truthy `_spark_ai_knowledge` is returned as-is. Otherwise a selection is
 * assembled from `_spark_ai_voiceTone` and `_spark_ai_targetAudience` (in that
 * order) under the shared 'netlify-known' scope.
 *
 * @param inputData - Input data submitted with the action, if any.
 * @returns The knowledge selection, or `undefined` when nothing was selected.
 */
function getKnowledgeForInputData(inputData?: Record<string, any>): SparkKnowledge | undefined {
    if (!inputData) {
        return undefined;
    }

    const explicitKnowledge = inputData._spark_ai_knowledge;
    if (explicitKnowledge) {
        return explicitKnowledge;
    }

    const candidates: [string, any][] = [
        ['voice-and-tone', inputData._spark_ai_voiceTone],
        ['target-audience', inputData._spark_ai_targetAudience]
    ];
    const selected: SparkKnowledgeSelection[] = candidates.filter(([, id]) => id).map(([type, id]) => ({ type, id }));

    return selected.length > 0 ? { scopes: { shared: 'netlify-known' }, selected } : undefined;
}
/**
 * Resolves the source content for the generation.
 *
 * Remote sources are tried one after another — Google Docs first, then
 * Notion — and the first one that yields content wins. When neither does, the
 * raw `_spark_ai_content` input is returned.
 *
 * @param options - Run options carrying `inputData` and the current user.
 * @param progressCallback - Receives progress steps from the fetchers.
 * @param logger - Logger for debug output.
 * @returns The resolved content.
 */
async function fetchContent(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string> {
    const remoteSources = [
        { load: fetchGoogleDoc, foundMessage: 'Got Google doc' },
        { load: fetchNotionPage, foundMessage: 'Got Notion page' }
    ];

    // Sequential on purpose: a later source is only consulted when the earlier one found nothing.
    for (const source of remoteSources) {
        const markdown = await source.load(options, progressCallback, logger);
        if (markdown) {
            logger.debug(source.foundMessage);
            return markdown;
        }
    }

    return options.inputData?._spark_ai_content;
}
/**
 * Fetches a Google Doc as markdown when the action input references one.
 *
 * The reference is taken from `_spark_ai_googleDocUrl`, or — when that is not
 * set — from `_spark_ai_content`, but only if the content consists of nothing
 * but a single whitespace-free value. Free-form content that merely mentions a
 * Google Docs link is therefore left alone and used as literal content.
 *
 * @param options - Run options carrying `inputData` and the current user.
 * @param progressCallback - Receives a progress step before the download starts.
 * @param logger - Logger for debug output.
 * @returns The exported document, or `undefined` when the input does not
 *          reference a Google Doc.
 * @throws Error 'Please connect your Google account' when the current user has
 *         no Google connection with an access token.
 * @throws Error 'Failed to fetch Google doc: …' when the export request fails.
 */
async function fetchGoogleDoc(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string | undefined> {
    let googleDocUrl: unknown = options.inputData?._spark_ai_googleDocUrl;
    if (googleDocUrl == null) {
        const content: unknown = options.inputData?._spark_ai_content;
        // Only treat the content field as a doc reference when it is a lone URL-like
        // value; otherwise a link inside pasted prose would replace the prose itself.
        if (typeof content !== 'string' || /\s/.test(content.trim())) {
            return;
        }
        googleDocUrl = content.trim();
    }
    // Guard against non-string input, which has no `.match`.
    if (typeof googleDocUrl !== 'string') {
        return;
    }

    const match = googleDocUrl.match(/docs\.google\.com\/document\/d\/([^/?#]+)/);
    if (!match) {
        return;
    }
    const googleDocId = match[1];
    if (!googleDocId) {
        return;
    }
    logger.debug('Found Google doc URL');

    let googleConnection;
    if (options.currentUser && 'connections' in options.currentUser && Array.isArray(options.currentUser.connections)) {
        googleConnection = options.currentUser.connections.find((connection) => connection.type === 'google');
    }
    if (!googleConnection || !('accessToken' in googleConnection)) {
        logger.debug('User does not have Google connection');
        throw new Error('Please connect your Google account');
    }

    progressCallback({
        percent: 2,
        categoryMessage: 'Initializing',
        progressMessage: `Fetching from ${googleDocUrl}`
    });
    logger.debug('Fetching Google doc');

    // The ID comes from user input, so encode it before placing it in the path.
    const url = `https://www.googleapis.com/drive/v3/files/${encodeURIComponent(googleDocId)}/export?mimeType=text/markdown`;
    const googleDocResult = await axios({
        url: url,
        method: 'get',
        headers: {
            Authorization: `Bearer ${googleConnection.accessToken}`
        }
    }).catch((error) => {
        throw new Error(`Failed to fetch Google doc: ${error.message}`);
    });
    return googleDocResult.data;
}
/**
 * Fetches a Notion page as markdown when the action input references one.
 *
 * The URL is read from `_spark_ai_notionUrl`, falling back to
 * `_spark_ai_content`. Supported URL shapes:
 *   https://www.notion.so/{COMPANY_SLUG}/{PAGE_SLUG}-{PAGE_ID}
 *   https://www.notion.so/{PAGE_ID}
 *   https://notion.so/{PAGE_SLUG}-{PAGE_ID}
 *
 * @param options - Run options carrying `inputData` and the current user.
 * @param progressCallback - Receives a progress step before the download starts.
 * @param logger - Logger for debug output.
 * @returns The page converted to markdown, or `undefined` when the input is not
 *          a Notion page URL.
 * @throws Error 'Please connect your Notion account' when the current user has
 *         no Notion connection with an access token.
 */
async function fetchNotionPage(options: CustomActionRunCommonOptions, progressCallback: ProgressCallback, logger: Logger): Promise<string | undefined> {
    const notionUrl = options.inputData?._spark_ai_notionUrl ?? options.inputData?._spark_ai_content;

    // Cheap pre-filter before attempting a full URL parse.
    if (!/notion\.so\//.test(notionUrl)) {
        return undefined;
    }

    let parsedUrl: URL;
    try {
        parsedUrl = new URL(notionUrl);
    } catch (error) {
        return undefined;
    }
    const host = parsedUrl.hostname;
    if (host !== 'www.notion.so' && host !== 'notion.so') {
        return undefined;
    }

    // The notion page ID is encoded into the last section of the URL path after the last hyphen
    // https://www.notion.so/Hello-World-12a4dc5adgh230e5ad5ad56456b457d9
    //                                   ↳        Notion Page ID        ↲
    const lastPathSegment = parsedUrl.pathname.replace(/^\//, '').split('/').pop();
    if (!lastPathSegment) {
        return undefined;
    }
    // Without a hyphen lastIndexOf is -1, so the whole segment is taken as the ID.
    const notionPageId = lastPathSegment.slice(lastPathSegment.lastIndexOf('-') + 1);
    if (!notionPageId) {
        return undefined;
    }
    logger.debug(`Found Notion page URL with ID: ${notionPageId}`);

    const user = options.currentUser;
    const connections = user && 'connections' in user && Array.isArray(user.connections) ? user.connections : undefined;
    const notionConnection = connections?.find((connection) => connection.type === 'notion');
    if (!notionConnection || !('accessToken' in notionConnection)) {
        logger.debug('User does not have Notion connection');
        throw new Error('Please connect your Notion account');
    }

    progressCallback({
        percent: 2,
        categoryMessage: 'Initializing',
        progressMessage: `Fetching from ${notionUrl}`
    });
    logger.debug('Fetching Notion page');

    const notionToMarkdown = new NotionToMarkdown({
        notionClient: new NotionClient({ auth: notionConnection.accessToken })
    });
    const markdownBlocks = await notionToMarkdown.pageToMarkdown(notionPageId);
    return notionToMarkdown.toMarkdownString(markdownBlocks).parent;
}
/**
* Exported collection of the custom-action factories defined in this file.
*/
export const Actions = {
GenerateContentFromPreset
};