gpt3rocket
Little helper utility for priming and transacting with the GPT-3 API
import axios from "axios";
export interface APIResponseChoice {
text: string;
index: number;
logprobs?: number;
finish_reason: string;
}
export interface APIResponse {
id: string;
object: string;
created: number;
model: string;
choices: APIResponseChoice[];
}
export type Sample = string[][];
export type Samples = [string, string] | [string, string][] | [] | string[][];
export interface APIFlags {
engine?: string; // The engine model ID -- there are 4 models (ada, babbage, curie, davinci)
prompt?: string; // One or more prompts to generate from. Can be a string, list of strings, a list of integers (i.e. a single prompt encoded as tokens), or list of lists of integers (i.e. many prompts encoded as integers).
max_tokens?: number; // How many tokens to complete to. Can return fewer if a stop sequence is hit.
temperature?: number; // What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. We generally recommend using this or top_p but not both.
top_p?: number; // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend using this or temperature but not both.
n?: number; // How many choices to create for each prompt.
stream?: boolean; // Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
logprobs?: number; // Include the log probabilities on the logprobs most likely tokens. So for example, if logprobs is 10, the API will return a list of the 10 most likely tokens. If logprobs is supplied, the API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.
stop?: string; // One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
[key: string]: any;
}
export interface APIConfig {
full_response?: boolean; // Defaults to false; if true, .ask() returns the full API response with metadata
endpoint?: string; // Defaults to 'https://api.openai.com/v1/engines/davinci/completions' (can change engine in APIConfig)
}
export interface RootConfig {
samples?: Samples; // Training samples
prefix?: string; // Top-line prefix to "set the table" of the API interaction
credential: string; // sensitive key
APIConfig?: APIConfig;
APIFlags?: APIFlags;
transform?: any; // Optional function to control how the prefix, samples & prompt are assembled (see _transformer for the default)
inputString?: string; // Label for samples, defaults to "input"
outputString?: string; // Label for samples, defaults to "output"
debug?: boolean;
}
/**
* ## Opts: Samples & prefix
* Samples & a prefix string will prime your agent
*
* ### opts.samples (optional)
*
* *array of Samples*
*
* ```ts
 * const samples = [['marco', 'polo'], ['marrrrrccccoo', 'pollllooo']];
* ```
* ### opts.prefix (optional)
* String to prepend to top of message as "primer"
*
* *string*
*
*```ts
 * const prefix = 'The following is an exchange between the user and an intelligent agent. The agent is friendly, prompt, and wants to help the user.';
* ```
*
 * ## Transform (optional)
 * An optional function to adjust how the prefix & samples are structured when sent to the API
 *
 * Receives prompt, samples, prefix, inputString, outputString
 * Without a custom function, the assembled prompt will look like the following
*
* ```
* Prefix phrase ____
* input: aaa
* output: bbb
* input: ${user_prompt_goes_here}
*```
*
*
 * ```ts
 * const transform = (prompt, samples, prefix, inputString, outputString) => {
 *   const decoratedSamples = samples.flat().map((example, idx) => {
 *     if (!(idx % 2)) {
 *       return `${inputString}:${example}`;
 *     } else {
 *       return `${outputString}:${example}`;
 *     }
 *   });
 *
 *   return `${prefix}\n${decoratedSamples.join("\n")}\n${inputString}:${prompt}\n`;
 * };
 * ```
*
 * ## APIFlags
 * ```
 * engine?: string; // The engine ID, defaults to davinci (ada, babbage, curie, davinci)
 * prompt?: string; // One or more prompts to generate from. Can be a string, list of strings, a list of integers (i.e. a single prompt encoded as tokens), or list of lists of integers (i.e. many prompts encoded as integers).
 * max_tokens?: number; // How many tokens to complete to. Can return fewer if a stop sequence is hit.
 * temperature?: number; // What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. We generally recommend using this or top_p but not both.
 * top_p?: number; // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend using this or temperature but not both.
 * n?: number; // How many choices to create for each prompt.
 * stream?: boolean; // Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
 * logprobs?: number; // Include the log probabilities on the logprobs most likely tokens. So for example, if logprobs is 10, the API will return a list of the 10 most likely tokens. If logprobs is supplied, the API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.
 * stop?: string; // One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
 * ```
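 *
 * ## Example
 * A minimal usage sketch (the import path and key value are placeholders; the credential must be a valid OpenAI API key):
 * ```ts
 * import { GPT3Rocket } from "gpt3rocket";
 *
 * const rocket = new GPT3Rocket({
 *   credential: "sk-************",
 *   prefix: "This is a conversation with a helpful agent.",
 *   samples: [["What are you?", "I am a helper agent here to answer your questions!"]],
 * });
 *
 * const reply = await rocket.ask("What can you do?");
 * console.log(reply.text);
 * ```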
*/
export class GPT3Rocket {
public config: RootConfig;
constructor(configRef: RootConfig) {
// ["What are you?", "I am a helper agent here to answer your questions!"],
const defaults = {
samples: [],
prefix: `This is a conversation with a helpful agent. The agent is kind, clever and eager to help`,
transform: this._transformer,
inputString: "input",
outputString: "output",
credential: "______________________-",
APIConfig: {
endpoint: "https://api.openai.com/v1/engines/davinci/completions",
full_response: false,
},
APIFlags: {
max_tokens: 20,
temperature: 0.3,
stop: "\n",
},
debug: false,
};
const mergeAPIFlags = Object.assign(
defaults.APIFlags,
configRef.APIFlags || {}
);
this.config = Object.assign(defaults, configRef, {
APIFlags: mergeAPIFlags,
});
this.__debug("<gpt3-rocket> Root config:", this.config);
}
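/**
* Assemble the full prompt text sent to the API (prefix + decorated samples + the user's prompt),
* using the configured transform function or falling back to the default _transformer
*/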
buildQuery(prompt: string, samples: Samples = [], prefix: string = "") {
let prefixRef = this.config.prefix || "";
if (prefix) {
prefixRef = prefix;
}
let sampleRef = this.config.samples || [];
if (samples && samples.length) {
// Q: merge samples?
sampleRef = samples;
}
if (typeof this.config.transform === "function") {
return this.config.transform(
prompt,
sampleRef,
prefixRef as string,
this.config.inputString,
this.config.outputString
);
} else {
return this._transformer(
prompt,
sampleRef,
prefixRef as string,
this.config.inputString as string,
this.config.outputString as string
);
}
}
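/**
* Send a prompt to the API, with optional per-call samples, prefix, APIFlags and APIConfig overrides.
* Resolves with the raw API payload when full_response is true, otherwise with { text } holding
* the first completion (output label stripped)
*/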
async ask(
prompt: string,
samples: Samples = [],
prefix: string = "",
APIFlags: APIFlags = {},
APIConfig: APIConfig = {}
): Promise<any> {
let query = prompt;
if (samples && samples.length) {
query = this.buildQuery(prompt, samples, prefix);
} else {
query = this.buildQuery(prompt, this.config.samples, this.config.prefix);
}
// Merge per-call overrides on top of the instance config (without mutating it)
const mergedAPIConfig = Object.assign({}, this.config.APIConfig, APIConfig);
const mergedAPIFlags = Object.assign({}, this.config.APIFlags, APIFlags);
this.__debug("<gpt3-rocket> Query: ", query);
const endpoint = mergedAPIConfig.endpoint;
let error: boolean | any = false;
const result = await axios
.post(
endpoint as string,
{
prompt: query,
...mergedAPIFlags,
},
{
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.config.credential}`,
},
}
)
.catch((e) => {
this.__debug("<gpt3-rocket> ERROR:", e.response);
if (e.response && e.response.status === 401) {
console.log(`\n\n<YOUR CREDENTIAL IS LIKELY INVALID>\n\n`);
}
// Capture the API's error payload when present, otherwise the raw error,
// so callers always get a response instead of undefined
error = (e.response && e.response.data && e.response.data.error) || e;
});
const { full_response } = mergedAPIConfig;
if (!error) {
if (full_response && result) {
return result.data;
} else if (result) {
const res = result.data.choices[0].text || "";
const target = `${this.config.outputString}:`; // ex output:
return { text: res.replace(target, "") };
}
} else {
if (full_response) {
return error;
} else {
return {
text:
error.message || "There was a problem (your key might be invalid)",
};
}
}
}
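/**
* Append a single [input, output] sample pair to the configured samples
*/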
add(sample: [string, string]) {
if (!Array.isArray(sample) || sample.length !== 2) {
throw new Error("Sample should be exactly one input & one output");
}
//@ts-ignore
this?.config?.samples?.push(sample);
}
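/**
* Set (replace) the top-line prefix used to prime the conversation
*/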
addPrefix(prefix: string) {
this.config.prefix = prefix;
}
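/**
* Swap in a custom transform with the signature (prompt, samples, prefix, inputString, outputString) => string
*/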
changeTransformer(
transformerFunction: (
prompt: string,
samples: Samples,
prefix: string,
inputString: string,
outputString: string
) => string
) {
if (typeof transformerFunction === "function") {
this.config.transform = transformerFunction;
}
}
resetTransformer() {
this.config.transform = this._transformer;
}
clear() {
this.clearSamples();
this.clearPrefix();
}
clearSamples() {
this.config.samples = [];
}
clearPrefix() {
this.config.prefix = "";
}
updateCredential(credential: string) {
this.config.credential = credential;
}
__debug(...payload: any) {
if (this.config.debug) {
console.log.apply(console, payload);
}
}
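/**
* Default transform: flattens the sample pairs, labels them alternately with inputString / outputString,
* and appends the user's prompt as the final input line
*/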
_transformer(
prompt: string,
samples: Samples,
prefix: string,
inputString: string,
outputString: string
) {
//@ts-ignore
const decoratedSamples = [].concat(...samples).map((example, idx) => {
if (!(idx % 2)) {
return `${inputString}:${example}`;
} else {
return `${outputString}:${example}`;
}
});
if (prefix && decoratedSamples.length) {
return `${prefix}\n${decoratedSamples.join(
"\n"
)}\n${inputString}:${prompt}\n`;
} else {
return `${inputString}:${prompt}\n`;
}
}
}
/**
* ENDPOINT
* Factory returning an Express/Connect-style handler bound to a fresh GPT3Rocket instance.
* Expects a parsed JSON body with { prompt, samples?, prefix?, APIFlags?, APIConfig? } and
* responds 200 with the result of .ask()
*/
export const gpt3Endpoint = (config: RootConfig) => {
const inst = new GPT3Rocket(config);
// TODO: req/res types, body-parser/no body-parser
return async (req: any, res: any, next: any) => {
const {
samples = [],
prefix = "",
APIConfig = {},
APIFlags = {},
prompt,
} = req.body;
const result = await inst.ask(prompt, samples, prefix, APIFlags, APIConfig);
return res.status(200).send(result);
};
};
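
// A minimal mounting sketch (assumes an Express app with JSON body parsing already applied;
// the import path, port, and OPENAI_KEY env var name are placeholders, not part of this module):
//
// import express from "express";
// import { gpt3Endpoint } from "./gpt3rocket";
//
// const app = express();
// app.use(express.json()); // the handler reads req.body, so a JSON body parser is required
// app.post("/gpt3", gpt3Endpoint({ credential: process.env.OPENAI_KEY as string }));
// app.listen(3000);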