@heroku/plugin-ai
Heroku CLI plugin for the Heroku AI add-on
Object.defineProperty(exports, "__esModule", { value: true });
const tslib_1 = require("tslib");
const command_1 = require("@heroku-cli/command");
const core_1 = require("@oclif/core");
const node_fs_1 = tslib_1.__importDefault(require("node:fs"));
const base_1 = tslib_1.__importDefault(require("../../../lib/base"));
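/**
 * `heroku ai:models:call` — makes an inference request to an AI model
 * resource attached to an app, dispatching on the model's reported type
 * (text-to-text, text-to-image, or text-to-embedding).
 */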
class Call extends base_1.default {
    static args = {
        model_resource: core_1.Args.string({
            description: 'resource ID or alias of model (--app flag required if alias is used)',
            required: true,
        }),
    };
    static description = 'make an inference request to a specific AI model resource';
    static examples = [
        'heroku ai:models:call my_llm --app my-app --prompt "What is the meaning of life?"',
        'heroku ai:models:call diffusion --app my-app --prompt "Generate an image of a sunset" --opts \'{"quality":"hd"}\' -o sunset.png',
    ];
    static flags = {
        app: command_1.flags.app({
            required: false,
            description: 'name or ID of app (required if alias is used)',
        }),
        // interactive: flags.boolean({
        //   char: 'i',
        //   description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
        //   default: false,
        // }),
        json: command_1.flags.boolean({ char: 'j', description: 'output response as JSON' }),
        optfile: command_1.flags.string({
            description: 'additional options for model inference, provided as a JSON config file',
            required: false,
        }),
        opts: command_1.flags.string({
            description: 'additional options for model inference, provided as a JSON string',
            required: false,
        }),
        output: command_1.flags.string({
            char: 'o',
            // description: 'The file path where the command writes the model response. If used with --interactive, this flag writes the entire exchange when the session closes.',
            description: 'file path where command writes the model response',
            required: false,
        }),
        prompt: command_1.flags.string({
            char: 'p',
            description: 'input prompt for model',
            required: false,
        }),
        remote: command_1.flags.remote(),
    };
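    /**
     * Parse the command line, resolve the target model resource, then dispatch
     * the request to the handler matching the model's type.
     */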
    async run() {
        let flags = {};
        let args = {};
        try {
            ({ args, flags } = await this.parse(Call));
        }
        catch (error) {
            const { parse: { output } } = error;
            ({ args, flags } = output);
        }
        const { model_resource: modelResource } = args;
        const { app, json, optfile, opts, output, prompt } = flags;
        if (!prompt && !optfile && !opts) {
            throw new Error('You must provide either --prompt, --optfile, or --opts.');
        }
        // Initially, configure the default client to fetch the available model classes
        await this.configureHerokuAIClient();
        const { body: availableModels } = await this.herokuAI.get('/available-models');
        // Now, configure the client to send a request for the target model resource
        await this.configureHerokuAIClient(modelResource, app);
        const options = this.parseOptions(optfile, opts);
        // Not sure why `type` is an array in ModelListItem, we use the type from the first entry.
        const modelType = availableModels.find(m => m.model_id === this.apiModelId)?.type[0];
        // Note: modelType will always be lower case. MarcusBlankenship 11/13/24.
        switch (modelType) {
            case 'text-to-embedding': {
                const embedding = await this.createEmbedding(prompt, options);
                await this.displayEmbedding(embedding, output, json);
                break;
            }
            case 'text-to-image': {
                const image = await this.generateImage(prompt, options);
                await this.displayImageResult(image, output, json);
                break;
            }
            case 'text-to-text': {
                const completion = await this.createChatCompletion(prompt, options);
                await this.displayChatCompletion(completion, output, json);
                break;
            }
            default:
                throw new Error(`Unsupported model type: ${modelType}`);
        }
    }
    /**
     * Parse the model call request options from the command flags.
     *
     * @param optfile Path to a JSON file containing options.
     * @param opts JSON string containing options.
     * @returns The parsed options as an object.
     */
    parseOptions(optfile, opts) {
        const options = {};
        if (optfile) {
            const optfileContents = node_fs_1.default.readFileSync(optfile);
            try {
                Object.assign(options, JSON.parse(optfileContents.toString()));
            }
            catch (error) {
                if (error instanceof SyntaxError) {
                    const { message } = error;
                    throw new Error(`Invalid JSON in ${optfile}. Check the formatting in your file.\n${message}`);
                }
                throw error;
            }
        }
        if (opts) {
            try {
                Object.assign(options, JSON.parse(opts));
            }
            catch (error) {
                if (error instanceof SyntaxError) {
                    const { message } = error;
                    throw new Error(`Invalid JSON. Check the formatting in your --opts value.\n${message}`);
                }
                throw error;
            }
        }
        return options;
    }
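    /**
     * Send a chat completion request for the attached text-to-text model.
     * A --prompt value is appended to any `messages` supplied through
     * --opts/--optfile as a final 'user' message.
     *
     * @param prompt Input prompt for the model.
     * @param options Additional inference options parsed from the flags.
     * @returns The chat completion response body.
     */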
    async createChatCompletion(prompt, options = {}) {
        const { prompt: optsPrompt, messages = [], ...rest } = options;
        if (prompt) {
            messages.push({ role: 'user', content: prompt ?? optsPrompt });
        }
        const { body: chatCompletionResponse } = await this.herokuAI.post('/v1/chat/completions', {
            body: {
                ...rest,
                messages,
                model: this.apiModelId,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return chatCompletionResponse;
    }
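    /**
     * Display a chat completion response, writing it to the --output file when
     * one is given, otherwise to stdout. With --json, the full response body is
     * emitted instead of just the message content.
     *
     * @param completion The chat completion response body.
     * @param output Optional file path for the response.
     * @param json Whether to output the full response as JSON.
     */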
    async displayChatCompletion(completion, output, json = false) {
        const content = completion.choices[0].message.content || '';
        if (output) {
            node_fs_1.default.writeFileSync(output, json ? JSON.stringify(completion, null, 2) : content);
        }
        else {
            json ? core_1.ux.styledJSON(completion) : core_1.ux.log(content);
        }
    }
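    /**
     * Send an image generation request for the attached text-to-image model.
     * A --prompt value takes precedence over a `prompt` supplied through
     * --opts/--optfile.
     *
     * @param prompt Input prompt for the model.
     * @param options Additional inference options parsed from the flags.
     * @returns The image generation response body.
     */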
    async generateImage(prompt, options = {}) {
        const { prompt: optsPrompt, ...rest } = options;
        const { body: imageResponse } = await this.herokuAI.post('/v1/images/generations', {
            body: {
                ...rest,
                model: this.apiModelId,
                prompt: prompt ?? optsPrompt,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return imageResponse;
    }
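    /**
     * Display an image generation response. Base64 payloads are decoded and
     * written to the --output file or streamed to stdout; URL payloads are
     * written or logged as-is. With --json, the full response body is emitted.
     *
     * @param image The image generation response body.
     * @param output Optional file path for the response.
     * @param json Whether to output the full response as JSON.
     */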
    async displayImageResult(image, output, json = false) {
        if (image.data[0].b64_json) {
            if (output) {
                const content = json ? JSON.stringify(image, null, 2) : Buffer.from(image.data[0].b64_json, 'base64');
                node_fs_1.default.writeFileSync(output, content);
            }
            else {
                json ? core_1.ux.styledJSON(image) : process.stdout.write(image.data[0].b64_json);
            }
            return;
        }
        if (image.data[0].url) {
            if (output) {
                node_fs_1.default.writeFileSync(output, json ? JSON.stringify(image, null, 2) : image.data[0].url);
            }
            else {
                json ? core_1.ux.styledJSON(image) : core_1.ux.log(image.data[0].url);
            }
            return;
        }
        // This should never happen, but we'll handle it anyway
        core_1.ux.error('Unexpected response format.', { exit: 1 });
    }
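    /**
     * Send an embedding request for the attached text-to-embedding model.
     * A --prompt value takes precedence over an `input` supplied through
     * --opts/--optfile.
     *
     * @param input Input text for the model.
     * @param options Additional inference options parsed from the flags.
     * @returns The embedding response body.
     */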
    async createEmbedding(input, options = {}) {
        const { input: optsInput, ...rest } = options;
        const { body: embeddingResponse } = await this.herokuAI.post('/v1/embeddings', {
            body: {
                ...rest,
                model: this.apiModelId,
                input: input ?? optsInput,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return embeddingResponse;
    }
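    /**
     * Display an embedding response, writing it to the --output file when one
     * is given, otherwise to stdout. With --json, the full response body is
     * emitted instead of the raw comma-separated vector.
     *
     * @param embedding The embedding response body.
     * @param output Optional file path for the response.
     * @param json Whether to output the full response as JSON.
     */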
    async displayEmbedding(embedding, output, json = false) {
        const content = (embedding.data[0].embeddings || []).toString();
        if (output) {
            node_fs_1.default.writeFileSync(output, json ? JSON.stringify(embedding, null, 2) : content);
        }
        else {
            json ? core_1.ux.styledJSON(embedding) : core_1.ux.log(content);
        }
    }
}
exports.default = Call;