@heroku/plugin-ai (compiled source of the ai:models:call command)

"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); const tslib_1 = require("tslib"); const command_1 = require("@heroku-cli/command"); const core_1 = require("@oclif/core"); const node_fs_1 = tslib_1.__importDefault(require("node:fs")); const base_1 = tslib_1.__importDefault(require("../../../lib/base")); class Call extends base_1.default { static args = { model_resource: core_1.Args.string({ description: 'resource ID or alias of model (--app flag required if alias is used)', required: true, }), }; static description = 'make an inference request to a specific AI model resource '; static examples = [ 'heroku ai:models:call my_llm --app my-app --prompt "What is the meaning of life?" ', 'heroku ai:models:call diffusion --app my-app --prompt "Generate an image of a sunset" --opts \'{"quality":"hd"}\' -o sunset.png ', ]; static flags = { app: command_1.flags.app({ required: false, description: 'name or ID of app (required if alias is used)', }), // interactive: flags.boolean({ // char: 'i', // description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)', // default: false, // }), json: command_1.flags.boolean({ char: 'j', description: 'output response as JSON ' }), optfile: command_1.flags.string({ description: 'additional options for model inference, provided as a JSON config file ', required: false, }), opts: command_1.flags.string({ description: 'additional options for model inference, provided as a JSON string ', required: false, }), output: command_1.flags.string({ char: 'o', // description: 'The file path where the command writes the model response. If used with --interactive, this flag writes the entire exchange when the session closes.', description: 'file path where command writes the model response', required: false, }), prompt: command_1.flags.string({ char: 'p', description: 'input prompt for model ', required: false, }), remote: command_1.flags.remote(), }; async run() { let flags = {}; let args = {}; try { ({ args, flags } = await this.parse(Call)); } catch (error) { const { parse: { output } } = error; ({ args, flags } = output); } const { model_resource: modelResource } = args; const { app, json, optfile, opts, output, prompt } = flags; if (!prompt && !optfile && !opts) { throw new Error('You must provide either --prompt, --optfile, or --opts.'); } // Initially, configure the default client to fetch the available model classes await this.configureHerokuAIClient(); const { body: availableModels } = await this.herokuAI.get('/available-models'); // Now, configure the client to send a request for the target model resource await this.configureHerokuAIClient(modelResource, app); const options = this.parseOptions(optfile, opts); // Not sure why `type` is an array in ModelListItem, we use the type from the first entry. const modelType = availableModels.find(m => m.model_id === this.apiModelId)?.type[0]; // Note: modelType will always be lower case. MarcusBlankenship 11/13/24. 
        switch (modelType) {
            case 'text-to-embedding': {
                const embedding = await this.createEmbedding(prompt, options);
                await this.displayEmbedding(embedding, output, json);
                break;
            }
            case 'text-to-image': {
                const image = await this.generateImage(prompt, options);
                await this.displayImageResult(image, output, json);
                break;
            }
            case 'text-to-text': {
                const completion = await this.createChatCompletion(prompt, options);
                await this.displayChatCompletion(completion, output, json);
                break;
            }
            default:
                throw new Error(`Unsupported model type: ${modelType}`);
        }
    }
    /**
     * Parse the model call request options from the command flags.
     * Options from --optfile are applied first and options from --opts second,
     * so --opts values override duplicate keys from the file.
     *
     * @param optfile Path to a JSON file containing options.
     * @param opts JSON string containing options.
     * @returns The parsed options as an object.
     */
    parseOptions(optfile, opts) {
        const options = {};
        if (optfile) {
            const optfileContents = node_fs_1.default.readFileSync(optfile);
            try {
                Object.assign(options, JSON.parse(optfileContents.toString()));
            } catch (error) {
                if (error instanceof SyntaxError) {
                    const { message } = error;
                    throw new Error(`Invalid JSON in ${optfile}. Check the formatting in your file.\n${message}`);
                }
                throw error;
            }
        }
        if (opts) {
            try {
                Object.assign(options, JSON.parse(opts));
            } catch (error) {
                if (error instanceof SyntaxError) {
                    const { message } = error;
                    throw new Error(`Invalid JSON. Check the formatting in your --opts value.\n${message}`);
                }
                throw error;
            }
        }
        return options;
    }
    async createChatCompletion(prompt, options = {}) {
        const { prompt: optsPrompt, messages = [], ...rest } = options;
        // Prefer the --prompt flag, falling back to a `prompt` key supplied
        // via --opts or --optfile.
        const content = prompt ?? optsPrompt;
        if (content) {
            messages.push({ role: 'user', content });
        }
        const { body: chatCompletionResponse } = await this.herokuAI.post('/v1/chat/completions', {
            body: {
                ...rest,
                messages,
                model: this.apiModelId,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return chatCompletionResponse;
    }
    async displayChatCompletion(completion, output, json = false) {
        const content = completion.choices[0].message.content || '';
        if (output) {
            node_fs_1.default.writeFileSync(output, json ? JSON.stringify(completion, null, 2) : content);
        } else {
            json ? core_1.ux.styledJSON(completion) : core_1.ux.log(content);
        }
    }
    async generateImage(prompt, options = {}) {
        const { prompt: optsPrompt, ...rest } = options;
        const { body: imageResponse } = await this.herokuAI.post('/v1/images/generations', {
            body: {
                ...rest,
                model: this.apiModelId,
                prompt: prompt ?? optsPrompt,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return imageResponse;
    }
    async displayImageResult(image, output, json = false) {
        if (image.data[0].b64_json) {
            if (output) {
                const content = json ? JSON.stringify(image, null, 2) : Buffer.from(image.data[0].b64_json, 'base64');
                node_fs_1.default.writeFileSync(output, content);
            } else {
                json ? core_1.ux.styledJSON(image) : process.stdout.write(image.data[0].b64_json);
            }
            return;
        }
        if (image.data[0].url) {
            // Without --output, the URL is shown only when --json is set.
            if (output) {
                node_fs_1.default.writeFileSync(output, json ? JSON.stringify(image, null, 2) : image.data[0].url);
            } else if (json) {
                core_1.ux.styledJSON(image);
            }
            return;
        }
        // This should never happen, but we'll handle it anyway.
        core_1.ux.error('Unexpected response format.', { exit: 1 });
    }
    async createEmbedding(input, options = {}) {
        const { input: optsInput, ...rest } = options;
        const { body: embeddingResponse } = await this.herokuAI.post('/v1/embeddings', {
            body: {
                ...rest,
                model: this.apiModelId,
                input: input ?? optsInput,
            },
            headers: { authorization: `Bearer ${this.apiKey}` },
        });
        return embeddingResponse;
    }
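    // As with the other display* helpers: write to the --output file when
    // given, otherwise print to stdout; --json switches either path to the
    // pretty-printed JSON response body.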
    async displayEmbedding(embedding, output, json = false) {
        const content = (embedding.data[0].embeddings || []).toString();
        if (output) {
            node_fs_1.default.writeFileSync(output, json ? JSON.stringify(embedding, null, 2) : content);
        } else {
            json ? core_1.ux.styledJSON(embedding) : core_1.ux.log(content);
        }
    }
}
exports.default = Call;
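// Usage sketch: `opts.json` and the option values below are illustrative,
// not part of this plugin. Because parseOptions applies --optfile first and
// --opts second, inline --opts values override duplicate keys from the file:
//
//   $ cat opts.json
//   { "temperature": 0.5, "max_tokens": 256 }
//
//   $ heroku ai:models:call my_llm --app my-app \
//       --optfile opts.json \
//       --opts '{"temperature":0.9}' \
//       --prompt "What is the meaning of life?"
//
// The chat-completions request body then carries temperature 0.9 and
// max_tokens 256 (assuming the target model accepts those keys) along with
// the prompt as a single user message.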