// ai
// AI SDK by Vercel - The AI Toolkit for TypeScript and JavaScript
// Version: (not captured in this listing)
// Listing metadata: 388 lines (343 loc) • 10.5 kB • text/typescript
import type {
Experimental_VideoModelV3,
Experimental_VideoModelV3CallOptions,
Experimental_VideoModelV3File,
SharedV3ProviderMetadata,
} from '@ai-sdk/provider';
import {
convertBase64ToUint8Array,
type DataContent,
type ProviderOptions,
withUserAgentSuffix,
} from '@ai-sdk/provider-utils';
import { NoVideoGeneratedError } from '../error/no-video-generated-error';
import {
DefaultGeneratedFile,
type GeneratedFile,
} from '../generate-text/generated-file';
import { logWarnings } from '../logger/log-warnings';
import { resolveVideoModel } from '../model/resolve-model';
import type { VideoModel } from '../types/video-model';
import type { VideoModelResponseMetadata } from '../types/video-model-response-metadata';
import type { Warning } from '../types/warning';
import {
detectMediaType,
imageMediaTypeSignatures,
videoMediaTypeSignatures,
} from '../util/detect-media-type';
import { download } from '../util/download/download';
import { prepareRetries } from '../util/prepare-retries';
import { VERSION } from '../version';
import type { GenerateVideoResult } from './generate-video-result';
import { splitDataUrl } from '../prompt/split-data-url';
/**
 * Prompt accepted by `experimental_generateVideo`.
 *
 * Either a plain text prompt, or an object that supplies an input image
 * together with an optional text prompt.
 */
export type GenerateVideoPrompt =
| string
| {
/**
 * Input image. String values are interpreted as an `http://` / `https://`
 * URL, a `data:` URL, or otherwise as base64-encoded image data;
 * `Uint8Array` values are passed on as raw file bytes.
 */
image: DataContent;
/**
 * Optional text prompt that accompanies the image.
 */
text?: string;
};
/**
 * Generates videos using a video model.
 *
 * The requested number of videos is split into one or more model calls (each
 * limited to `maxVideosPerCall` videos) that run in parallel. Videos that the
 * model returns as URLs are downloaded, also in parallel, so that every video
 * in the result is backed by file data. The order of the returned videos
 * follows the order of the model calls and of the videos within each call.
 *
 * @param model - The video model to use.
 * @param prompt - The prompt that should be used to generate the video.
 * @param n - Number of videos to generate. Default: 1.
 * @param maxVideosPerCall - Maximum number of videos per model call. If omitted,
 * the limit reported by the model is used, falling back to 1.
 * @param aspectRatio - Aspect ratio of the videos to generate. Must have the format `{width}:{height}`.
 * @param resolution - Resolution of the videos to generate. Must have the format `{width}x{height}`.
 * @param duration - Duration of the video in seconds.
 * @param fps - Frames per second for the video.
 * @param seed - Seed for the video generation.
 * @param providerOptions - Additional provider-specific options that are passed through to the provider
 * as body parameters.
 * @param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2.
 * @param abortSignal - An optional abort signal that can be used to cancel the call.
 * @param headers - Additional HTTP headers to be sent with the request. Only applicable for HTTP-based providers.
 *
 * @returns A result object that contains the generated videos.
 *
 * @throws NoVideoGeneratedError when the model calls did not yield any video.
 */
export async function experimental_generateVideo({
  model: modelArg,
  prompt: promptArg,
  n = 1,
  maxVideosPerCall,
  aspectRatio,
  resolution,
  duration,
  fps,
  seed,
  providerOptions,
  maxRetries: maxRetriesArg,
  abortSignal,
  headers,
}: {
  /**
   * The video model to use.
   */
  model: VideoModel;
  /**
   * The prompt that should be used to generate the video.
   */
  prompt: GenerateVideoPrompt;
  /**
   * Number of videos to generate.
   */
  n?: number;
  /**
   * Maximum number of videos per API call. If not provided, the model's default will be used.
   */
  maxVideosPerCall?: number;
  /**
   * Aspect ratio of the videos to generate. Must have the format `{width}:{height}`.
   */
  aspectRatio?: `${number}:${number}`;
  /**
   * Resolution of the videos to generate. Must have the format `{width}x{height}`.
   */
  resolution?: `${number}x${number}`;
  /**
   * Duration of the video in seconds.
   */
  duration?: number;
  /**
   * Frames per second for the video.
   */
  fps?: number;
  /**
   * Seed for the video generation.
   */
  seed?: number;
  /**
   * Additional provider-specific options that are passed through to the provider
   * as body parameters.
   */
  providerOptions?: ProviderOptions;
  /**
   * Maximum number of retries per video model call. Set to 0 to disable retries.
   *
   * @default 2
   */
  maxRetries?: number;
  /**
   * Abort signal.
   */
  abortSignal?: AbortSignal;
  /**
   * Additional headers to include in the request.
   * Only applicable for HTTP-based providers.
   */
  headers?: Record<string, string>;
}): Promise<GenerateVideoResult> {
  const model = resolveVideoModel(modelArg);

  const headersWithUserAgent = withUserAgentSuffix(
    headers ?? {},
    `ai/${VERSION}`,
  );

  const { retry } = prepareRetries({
    maxRetries: maxRetriesArg,
    abortSignal,
  });

  const { prompt, image } = normalizePrompt(promptArg);

  // An explicit argument wins over the model's own limit; 1 is the last resort.
  const maxVideosPerCallWithDefault =
    maxVideosPerCall ?? (await invokeModelMaxVideosPerCall(model)) ?? 1;

  // parallelize calls to the model:
  const callCount = Math.ceil(n / maxVideosPerCallWithDefault);
  const callVideoCounts = Array.from({ length: callCount }, (_, index) => {
    // the last call only requests whatever is left over
    const remaining = n - index * maxVideosPerCallWithDefault;
    return Math.min(remaining, maxVideosPerCallWithDefault);
  });

  const results = await Promise.all(
    callVideoCounts.map(async callVideoCount =>
      retry(() =>
        model.doGenerate({
          prompt,
          n: callVideoCount,
          aspectRatio,
          resolution,
          duration,
          fps,
          seed,
          image,
          providerOptions: providerOptions ?? {},
          headers: headersWithUserAgent,
          abortSignal,
        } satisfies Experimental_VideoModelV3CallOptions),
      ),
    ),
  );

  // Convert the returned videos into files. URL videos need a download, so the
  // conversions run concurrently instead of one after another; Promise.all
  // resolves in input order, which keeps the videos in the model's order.
  const convertedResults = await Promise.all(
    results.map(async result => ({
      result,
      files: await Promise.all(
        result.videos.map(videoData => convertToGeneratedFile(videoData)),
      ),
    })),
  );

  // collect result videos, warnings, and response metadata
  const videos: Array<GeneratedFile> = [];
  const warnings: Array<Warning> = [];
  const responses: Array<VideoModelResponseMetadata> = [];
  const providerMetadata: SharedV3ProviderMetadata = {};

  for (const { result, files } of convertedResults) {
    for (const file of files) {
      // entries of an unrecognized kind were not converted and are skipped
      if (file != null) {
        videos.push(file);
      }
    }

    warnings.push(...result.warnings);

    responses.push({
      timestamp: result.response.timestamp,
      modelId: result.response.modelId,
      headers: result.response.headers,
      providerMetadata: result.providerMetadata,
    });

    if (result.providerMetadata != null) {
      mergeProviderMetadata(providerMetadata, result.providerMetadata);
    }
  }

  if (videos.length === 0) {
    throw new NoVideoGeneratedError({ responses });
  }

  if (warnings.length > 0) {
    logWarnings({
      warnings,
      provider: model.provider,
      model: model.modelId,
    });
  }

  return {
    video: videos[0],
    videos,
    warnings,
    responses,
    providerMetadata,
  };
}

/**
 * A single video entry as returned by a video model call.
 */
type GeneratedVideoData = Awaited<
  ReturnType<Experimental_VideoModelV3['doGenerate']>
>['videos'][number];

/**
 * Tells whether a reported media type is specific enough to be trusted.
 * Missing values and the generic `application/octet-stream` are rejected so
 * that callers fall through to signature-based detection.
 */
function isUsableMediaType(type: string | undefined): boolean {
  return !!type && type !== 'application/octet-stream';
}

/**
 * Converts one video entry returned by the model into a generated file.
 *
 * - `url`: the video is downloaded; the media type is taken from the model,
 *   then from the download, then from the file signature, then `video/mp4`.
 * - `base64`: the data is used as is; the media type defaults to `video/mp4`.
 * - `binary`: the data is used as is; a missing media type is detected from
 *   the file signature and otherwise defaults to `video/mp4`.
 *
 * @returns The generated file, or `undefined` for an unrecognized entry kind.
 */
async function convertToGeneratedFile(
  videoData: GeneratedVideoData,
): Promise<GeneratedFile | undefined> {
  switch (videoData.type) {
    case 'url': {
      const { data, mediaType: downloadedMediaType } = await download({
        url: new URL(videoData.url),
      });

      const mediaType =
        (isUsableMediaType(videoData.mediaType) && videoData.mediaType) ||
        (isUsableMediaType(downloadedMediaType) && downloadedMediaType) ||
        detectMediaType({
          data,
          signatures: videoMediaTypeSignatures,
        }) ||
        'video/mp4';

      return new DefaultGeneratedFile({
        data,
        mediaType,
      });
    }

    case 'base64': {
      return new DefaultGeneratedFile({
        data: videoData.data,
        mediaType: videoData.mediaType || 'video/mp4',
      });
    }

    case 'binary': {
      const mediaType =
        videoData.mediaType ||
        detectMediaType({
          data: videoData.data,
          signatures: videoMediaTypeSignatures,
        }) ||
        'video/mp4';

      return new DefaultGeneratedFile({
        data: videoData.data,
        mediaType,
      });
    }

    default:
      return undefined;
  }
}

/**
 * Merges the provider metadata of one model call into the accumulated
 * metadata (`target` is modified in place).
 *
 * Per provider, keys of the new metadata override existing keys, except for
 * `videos`: when both sides hold a `videos` array, the arrays are concatenated
 * so that the per-video metadata of all calls is retained.
 */
function mergeProviderMetadata(
  target: SharedV3ProviderMetadata,
  source: SharedV3ProviderMetadata,
): void {
  for (const [providerName, metadata] of Object.entries(source)) {
    const existingMetadata = target[providerName];

    if (existingMetadata != null && typeof existingMetadata === 'object') {
      target[providerName] = {
        ...existingMetadata,
        ...metadata,
      };

      // Merge videos arrays if both exist
      if (
        'videos' in existingMetadata &&
        Array.isArray(existingMetadata.videos) &&
        'videos' in metadata &&
        Array.isArray(metadata.videos)
      ) {
        (target[providerName] as { videos: unknown[] }).videos = [
          ...existingMetadata.videos,
          ...metadata.videos,
        ];
      }
    } else {
      target[providerName] = metadata;
    }
  }
}
/**
 * Splits a `GenerateVideoPrompt` into the text prompt and the optional input
 * image in the shape that is passed to the video model.
 *
 * Image handling:
 * - `http://` / `https://` strings are passed on as URL references.
 * - `data:` URLs are split into media type (default `image/png`) and payload.
 * - any other string is treated as base64-encoded image bytes.
 * - `Uint8Array` (which includes Node.js `Buffer`) and `ArrayBuffer` values
 *   are used as raw image bytes.
 *
 * For raw bytes and plain base64 strings the media type is detected from the
 * image signature and falls back to `image/png`.
 *
 * @param promptArg - The prompt as supplied by the caller.
 * @returns The text prompt (if any) and the normalized image (if any).
 */
function normalizePrompt(promptArg: GenerateVideoPrompt): {
  prompt: string | undefined;
  image: Experimental_VideoModelV3File | undefined;
} {
  if (typeof promptArg === 'string') {
    return {
      prompt: promptArg,
      image: undefined,
    };
  }

  // Wraps raw image bytes into a file, sniffing the media type from the bytes.
  const toImageFile = (bytes: Uint8Array): Experimental_VideoModelV3File => ({
    type: 'file',
    mediaType:
      detectMediaType({
        data: bytes,
        signatures: imageMediaTypeSignatures,
      }) ?? 'image/png',
    data: bytes,
  });

  const dataContent = promptArg.image;
  let image: Experimental_VideoModelV3File | undefined;

  if (typeof dataContent === 'string') {
    if (
      dataContent.startsWith('http://') ||
      dataContent.startsWith('https://')
    ) {
      image = {
        type: 'url',
        url: dataContent,
      };
    } else if (dataContent.startsWith('data:')) {
      const { mediaType, base64Content } = splitDataUrl(dataContent);
      image = {
        type: 'file',
        mediaType: mediaType ?? 'image/png',
        data: convertBase64ToUint8Array(base64Content ?? ''),
      };
    } else {
      image = toImageFile(convertBase64ToUint8Array(dataContent));
    }
  } else if (dataContent instanceof Uint8Array) {
    image = toImageFile(dataContent);
  } else if (dataContent instanceof ArrayBuffer) {
    // DataContent also permits ArrayBuffer; view it as bytes instead of
    // silently dropping the image.
    image = toImageFile(new Uint8Array(dataContent));
  }

  return {
    prompt: promptArg.text,
    image,
  };
}
/**
 * Reads the model's limit for the number of videos per call.
 *
 * The model may expose the limit as a plain value or as a function; a function
 * is invoked on the model with the model id and its result is awaited.
 *
 * @param model - The video model to query.
 * @returns The value exposed (or resolved) by the model's `maxVideosPerCall`.
 */
async function invokeModelMaxVideosPerCall(model: Experimental_VideoModelV3) {
  // Static value (possibly undefined): hand it back unchanged.
  if (typeof model.maxVideosPerCall !== 'function') {
    return model.maxVideosPerCall;
  }

  // Called as a method on `model` so that `this` stays bound to the model.
  return await model.maxVideosPerCall({ modelId: model.modelId });
}