langxlang
Version:
LLM wrapper for OpenAI GPT and Google Gemini and PaLM 2 models
305 lines (277 loc) • 11.6 kB
JavaScript
const openai = require('./backends/openai')
// const palm2 = require('./backends/palm2')
const gemini = require('./backends/gemini')
const debug = require('debug')('lxl')
const { checkDoesGoogleModelSupportInstructions, checkGuidance } = require('./util')
/**
 * Shared base for the provider-specific completion services.
 * It only stores the API key and reports whether one is available.
 */
class BaseCompleteService {
  /**
   * @param {string} [apiKey] - Credential stored on the instance as `apiKey`.
   */
  constructor (apiKey) {
    this.apiKey = apiKey
  }

  /**
   * @returns {boolean} `true` when a truthy API key has been configured.
   */
  ok () {
    return Boolean(this.apiKey)
  }
}
/**
 * Completion service backed by Google Gemini models.
 *
 * Converts LXL-style messages (entries carrying `text` or `parts`) into the
 * `{ role, parts }` shape built by `_processGeminiMessages` and forwards the
 * requests to the `gemini` backend module.
 */
class GeminiCompleteService extends BaseCompleteService {
  /**
   * @param {string} [apiKey] - Gemini API key; falls back to the GEMINI_API_KEY environment variable.
   */
  constructor (apiKey) {
    super(apiKey || process.env.GEMINI_API_KEY)
  }

  /**
   * Emulates plain text completion by wrapping `text` in a single user chat message.
   *
   * @param {string} model - Model identifier passed through to the backend.
   * @param {string} text - Text to be completed.
   * @param {object} [options] - Generation options, see `requestChatComplete`.
   * @param {Function} [chunkCb] - Optional streaming callback.
   * @returns {Promise<object[]>} Same result as `requestChatComplete`.
   */
  async requestCompletion (model, text, options, chunkCb) {
    // TODO: add support for proper code/text completion models
    const messages = [{ role: 'user', text: 'Please complete this text:\n' + text }]
    return this.requestChatComplete(model, messages, options, undefined, chunkCb)
  }

  /**
   * @returns {Promise<*>} Whatever the gemini backend's `listModels` returns for this API key.
   */
  async listModels () {
    return gemini.listModels(this.apiKey)
  }

  /**
   * Converts LXL messages into Gemini messages.
   *
   * Roles are remapped (`assistant`/`guidance` -> `model`, `system` -> `user` when the
   * model has no system instruction support), `text` is turned into a single text part,
   * and image parts are converted to `inlineData` (remote image URLs are downloaded first).
   * Messages that end up without any parts are dropped.
   *
   * @param {string} model - Model identifier, used to decide on system instruction support.
   * @param {object[]} messages - LXL messages; the input objects are not modified.
   * @returns {Promise<object[]>} Gemini-shaped messages.
   * @throws {Error} When a message has neither `text` nor an array of `parts`,
   *   or when a remote image cannot be fetched.
   */
  async _processGeminiMessages (model, messages) {
    debug('gemini.processMessages', JSON.stringify(messages))

    // Google Gemini doesn't support data URLs, or even remote ones, so remote images are
    // fetched into a data URL first and data URLs are then split into mime type + payload.
    async function resolveImage (url) {
      const res = await fetch(url)
      // Without this check an HTTP error page would be embedded as if it were image data
      if (!res.ok) throw new Error(`Failed to fetch image ${url}: HTTP ${res.status}`)
      const buffer = await res.arrayBuffer()
      return `data:${res.headers.get('content-type')};base64,${Buffer.from(buffer).toString('base64')}`
    }

    // 'data:<mime>;base64,<payload>' => { inlineData: { mimeType, data } }
    function splitDataURL (entry) {
      const mimeType = entry.slice(5, entry.indexOf(';'))
      const data = entry.slice(entry.indexOf(',') + 1)
      return { inlineData: { mimeType, data } }
    }

    // April 2024 - Only Gemini 1.5 supports instructions
    const supportsSystemInstruction = checkDoesGoogleModelSupportInstructions(model)
    // Placeholder parts that still hold a remote URL; they are filled in after the mapping pass
    const imagesForResolve = []

    const geminiMessages = messages.map((msg) => {
      const m = structuredClone(msg)
      if (msg.role === 'assistant') m.role = 'model'
      if (msg.role === 'system') m.role = supportsSystemInstruction ? 'system' : 'user'
      if (msg.role === 'guidance') m.role = 'model'

      if (msg.role === 'function') {
        // Only the first part of a function message is forwarded
        const [part] = msg.parts
        m.parts = [{
          functionResponse: {
            name: part.functionResponse.name,
            response: {
              name: part.functionResponse.name,
              content: part.functionResponse.response
            }
          }
        }]
        return m
      }

      if (msg.text) {
        m.parts = [{ text: msg.text }]
        delete m.text
        return m
      }

      // Array.isArray (rather than typeof === 'object') so that a null `parts` produces the
      // descriptive error below instead of a "not iterable" TypeError from the loop
      if (!Array.isArray(msg.parts)) {
        throw new Error('Message .parts should be an array of part objects: ' + JSON.stringify(msg))
      }

      const updated = []
      for (const entry of msg.parts) {
        if (entry.text) {
          updated.push({ text: entry.text })
        } else if (entry.imageURL) {
          const val = { imageURL: entry.imageURL }
          imagesForResolve.push(val)
          updated.push(val)
        } else if (entry.imageB64Url) {
          updated.push(splitDataURL(entry.imageB64Url))
        } else if (entry.mimeType) {
          const dataAsB64 = Buffer.from(entry.data).toString('base64')
          updated.push({ inlineData: { mimeType: entry.mimeType, data: dataAsB64 } })
        } else if (entry.functionCall) {
          updated.push({ functionCall: { name: entry.functionCall.name, args: entry.functionCall.args } })
        }
      }
      m.parts = updated
      return m
    }).filter((msg) => msg.parts && (msg.parts.length > 0))

    // The downloads are independent of each other, so run them concurrently
    await Promise.all(imagesForResolve.map(async (entry) => {
      const dataURL = await resolveImage(entry.imageURL)
      Object.assign(entry, splitDataURL(dataURL))
      delete entry.imageURL
    }))

    return geminiMessages
  }

  /**
   * Requests a chat completion from Gemini.
   *
   * @param {string} model - Model identifier.
   * @param {object[]} messages - LXL messages.
   * @param {object} [options] - Generation options: `maxTokens`, `stopSequences`,
   *   `temperature`, `topP`, `topK`. May be omitted.
   * @param {*} [functions] - Function definitions, passed to the backend unchanged.
   * @param {Function} [chunkCb] - Optional streaming callback; it also receives a final
   *   `{ done: true, delta: '' }` once a text response is complete.
   * @returns {Promise<object[]>} A single-element array holding either a `type: 'text'`
   *   or a `type: 'function'` result.
   * @throws {Error} When no API key is set or the response has neither text nor function calls.
   */
  async requestChatComplete (model, messages, { maxTokens, stopSequences, temperature, topP, topK } = {}, functions, chunkCb) {
    if (!this.apiKey) throw new Error('Gemini API key not set')
    // The value returned by checkGuidance (if any) is prepended to the model's answer below
    const guidance = checkGuidance(messages, chunkCb)
    const geminiMessages = await this._processGeminiMessages(model, messages)
    const response = await gemini.generateChatCompletionEx(model, geminiMessages, {
      apiKey: this.apiKey,
      functions,
      generationConfig: {
        maxOutputTokens: maxTokens,
        stopSequences,
        temperature,
        topP,
        topK
      }
    }, chunkCb)

    const answer = response.text()
    if (answer) {
      chunkCb?.({ done: true, delta: '' })
      const content = guidance ? guidance + answer : answer
      return [{
        type: 'text',
        isTruncated: response.finishReason === 'MAX_TOKENS',
        parts: [{ text: content }],
        safetyRatings: response.safetyRatings,
        text: content
      }]
    }

    const calls = response.functionCalls()
    if (calls) {
      return [{
        type: 'function',
        fnCalls: calls,
        // TODO: map the content parts here to LXL's format
        parts: response.parts,
        safetyRatings: response.safetyRatings
      }]
    }

    throw new Error('Unknown response from Gemini')
  }

  /**
   * Not implemented for Gemini; the returned promise always rejects.
   */
  async requestTranscription (model, audioStream, options) {
    throw new Error('Transcription is not supported for Gemini yet - use OpenAI instead')
  }

  /**
   * Not implemented for Gemini; always throws (synchronously, this method is not async).
   */
  requestSpeechSynthesis (model, text, options) {
    throw new Error('Speech synthesis is not supported for Gemini yet - use OpenAI instead')
  }

  /**
   * Counts the tokens of a single piece of content.
   *
   * @param {string} model - Model identifier.
   * @param {Array|*} content - An array is handed to the backend unchanged; any other value
   *   is treated as the text of a single user message and converted to parts first.
   * @returns {Promise<*>} The result of the gemini backend's `countTokens`.
   */
  async countTokens (model, content) {
    let parts = content
    if (!Array.isArray(content)) {
      const [a] = await this._processGeminiMessages(model, [{ role: 'user', text: content }])
      parts = a.parts
    }
    return gemini.countTokens(this.apiKey, model, parts)
  }

  /**
   * Counts the tokens of a list of LXL messages.
   *
   * @param {string} model - Model identifier.
   * @param {object[]} messages - LXL messages; converted with `_processGeminiMessages` first.
   * @returns {Promise<*>} The result of the gemini backend's `countTokens`.
   */
  async countTokensInMessages (model, messages) {
    // The conversion is async, so it has to be awaited before handing the messages over,
    // and the key is stored as `apiKey` by the base class.
    const geminiMessages = await this._processGeminiMessages(model, messages)
    return gemini.countTokens(this.apiKey, model, geminiMessages)
  }
}
/**
 * Completion service backed by the OpenAI API (or a compatible endpoint).
 *
 * Converts LXL-style messages (entries carrying `text` or `parts`) into OpenAI chat
 * messages and forwards the requests to the `openai` backend module.
 */
class OpenAICompleteService extends BaseCompleteService {
  /**
   * @param {string} [apiKey] - OpenAI API key; falls back to the OPENAI_API_KEY environment variable.
   * @param {string} [apiBase] - API base URL; falls back to the OPENAI_API_BASE environment variable.
   */
  constructor (apiKey, apiBase) {
    super(apiKey || process.env.OPENAI_API_KEY)
    this.apiBase = apiBase || process.env.OPENAI_API_BASE
  }

  /**
   * Emulates plain text completion by wrapping `text` in a single user chat message.
   *
   * @param {string} model - Model identifier passed through to the backend.
   * @param {string} text - Text to be completed.
   * @param {object} [options] - Generation options, see `requestChatComplete`.
   * @param {Function} [chunkCb] - Optional streaming callback.
   * @returns {Promise<object[]>} Same result as `requestChatComplete`.
   */
  async requestCompletion (model, text, options, chunkCb) {
    const messages = [{ role: 'user', text: 'Please complete the following text:\n' + text }]
    return this.requestChatComplete(model, messages, options, undefined, chunkCb)
  }

  /**
   * Converts one LXL message into an OpenAI chat message. The input is not modified.
   *
   * @param {object} entry - LXL message.
   * @returns {object} OpenAI-shaped message (`content` and/or `tool_calls` instead of `text`/`parts`).
   * @throws {Error} When a text part is not a string or inline data has no mime type.
   */
  _toOpenAIMessage (entry) {
    const msg = structuredClone(entry)
    if (msg.role === 'model') msg.role = 'assistant'
    if (msg.role === 'guidance') msg.role = 'assistant'

    if (msg.role === 'function') {
      // Only the first part of a function message is forwarded, as a `tool` message
      const [part] = msg.parts
      msg.role = 'tool'
      msg.content = part.functionResponse.response
      if (typeof msg.content !== 'string') msg.content = JSON.stringify(msg.content)
      msg.tool_call_id = part.functionResponse.id
      delete msg.parts
      return msg
    }

    if (msg.text != null) {
      delete msg.text
      msg.content = entry.text
    }

    if (typeof msg.parts === 'object') {
      const updated = []
      // Object.values walks own entries only (for...in would also pick up inherited
      // enumerable properties); `?? {}` keeps a null `parts` a no-op as before.
      for (const value of Object.values(msg.parts ?? {})) {
        if (value.text) {
          if (typeof value.text !== 'string') throw new Error('Expected part.text to be a string: ' + JSON.stringify(value))
          updated.push({ type: 'text', text: value.text })
        } else if (value.imageURL) {
          updated.push({ type: 'image_url', image_url: { url: value.imageURL, detail: value.imageDetail } })
        } else if (value.imageB64Url) {
          updated.push({ type: 'image_url', image_url: { url: value.imageB64Url, detail: value.imageDetail } })
        } else if (value.data) {
          if (!value.mimeType) throw new Error('Missing mimeType for inline data')
          // Strings are used as-is (unchanged behaviour). Binary data has to be base64
          // encoded first: interpolating a Uint8Array directly would produce "97,98,99".
          const b64 = typeof value.data === 'string' ? value.data : Buffer.from(value.data).toString('base64')
          updated.push({ type: 'image_url', image_url: { url: `data:${value.mimeType};base64,${b64}`, detail: value.imageDetail } })
        } else if (value.functionCall) {
          msg.tool_calls ??= []
          msg.tool_calls.push({ id: value.functionCall.id, type: 'function', function: { name: value.functionCall.name, arguments: JSON.stringify(value.functionCall.args) } })
        }
      }
      // Text-only content is collapsed into one plain string
      msg.content = updated.every((e) => e.type === 'text')
        ? updated.map((e) => e.text).join('')
        : updated
      delete msg.parts
    }
    return msg
  }

  /**
   * Requests a chat completion from OpenAI.
   *
   * @param {string} model - Model identifier.
   * @param {object[]} messages - LXL messages; messages that end up with neither content
   *   nor tool calls are not sent.
   * @param {object} [options] - Generation options: `maxTokens`, `stopSequences`,
   *   `temperature`, `topP`. May be omitted.
   * @param {*} [functions] - Function definitions, passed to the backend unchanged.
   * @param {Function} [chunkCb] - Optional streaming callback, passed to the backend.
   * @returns {Promise<object[]>} One result per returned choice, with `type`, `isTruncated`,
   *   `fnCalls`, `parts`, `text` and `requestUsage`.
   * @throws {Error} When no API key is set or a choice's content is not a string.
   */
  async requestChatComplete (model, messages, { maxTokens, stopSequences, temperature, topP } = {}, functions, chunkCb) {
    if (!this.apiKey) throw new Error('OpenAI API key not set')
    // The value returned by checkGuidance (if any) is prepended to each choice's content below
    const guidance = checkGuidance(messages, chunkCb)
    const openaiMessages = messages
      .map((entry) => this._toOpenAIMessage(entry))
      .filter((msg) => msg.content || msg.tool_calls)

    const response = await openai.generateChatCompletionIn(
      model,
      openaiMessages,
      {
        baseURL: this.apiBase,
        apiKey: this.apiKey,
        functions,
        generationConfig: {
          max_tokens: maxTokens,
          stop: stopSequences,
          temperature,
          top_p: topP
        }
      },
      chunkCb
    )

    return response.choices.map((choice) => {
      const choiceType = {
        stop: 'text',
        length: 'text',
        function_call: 'function',
        content_filter: 'safety', // an error would be thrown before this
        tool_calls: 'function'
      }[choice.finishReason] ?? 'unknown'
      const content = guidance ? guidance + choice.content : choice.content
      // assert that the content is a string as OpenAI can't interleave image and
      // text content yet... and we don't know how it'd look like outputwise if it did
      if (typeof content !== 'string') throw new Error('Expected content to be a string')
      return {
        type: choiceType,
        isTruncated: choice.finishReason === 'length',
        // { 0: { id: 'call_n', name: 'Classify', args: '{"choice":"Angry"}' } } => [ { id: 'call_n', name: 'Classify', args: { choice: 'Angry' } } ]
        fnCalls: choice.fnCalls && Object.values(choice.fnCalls).map((value) => ({ id: value.id, name: value.name, args: JSON.parse(value.args) })),
        parts: [{ text: content }],
        text: content,
        requestUsage: response.usage
      }
    })
  }

  /**
   * @returns {Promise<*>} The result of the openai backend's `transcribeAudioEx`.
   */
  async requestTranscription (model, audioStream, options) {
    return openai.transcribeAudioEx(this.apiBase, this.apiKey, model, audioStream, options)
  }

  /**
   * @returns {Promise<*>} The result of the openai backend's `synthesizeSpeechEx`.
   */
  async requestSpeechSynthesis (model, text, options) {
    return openai.synthesizeSpeechEx(this.apiBase, this.apiKey, model, text, options)
  }

  /**
   * @returns {Promise<object>} The models reported by the backend, keyed by their `id`.
   */
  async listModels () {
    const list = await openai.listModels(this.apiBase, this.apiKey)
    return Object.fromEntries(list.map((e) => ([e.id, e])))
  }

  /**
   * Counts the tokens in `content` using the local tokenizer helper.
   *
   * NOTE(review): the `model` argument is ignored and 'gpt-4' is always used - confirm
   * whether './tools/tokens' can be given the actual model.
   *
   * @returns {Promise<*>} The result of the tokenizer's `countTokens`.
   */
  async countTokens (model, content) {
    return require('./tools/tokens').countTokens('gpt-4', content)
  }

  /**
   * Sums the token counts of the `content` of every message.
   *
   * @param {string} model - Model identifier, forwarded to `countTokens`.
   * @param {object[]} messages - Messages whose `content` is counted.
   * @returns {Promise<number>} Total token count (0 for an empty list).
   */
  async countTokensInMessages (model, messages) {
    // countTokens is async: its promises must be resolved before they can be summed
    const counts = await Promise.all(messages.map((entry) => this.countTokens(model, entry.content)))
    return counts.reduce((total, count) => total + count, 0)
  }
}
// Public API: only the two provider-specific services are exported;
// BaseCompleteService is not part of the exports.
module.exports = { OpenAICompleteService, GeminiCompleteService }