UNPKG

@jackhua/mini-langchain

Version:

A lightweight TypeScript implementation of LangChain with cost optimization features

238 lines 9.24 kB
"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.Gemini = void 0;
exports.createGeminiFromEnv = createGeminiFromEnv;
const axios_1 = __importDefault(require("axios"));
const base_1 = require("./base");
/**
 * Build the JSON body shared by the `generateContent` and
 * `streamGenerateContent` requests.
 *
 * @param {Gemini} llm - Instance whose `defaultTopP` / `defaultTopK` are read.
 * @param {Array<object>} contents - Messages already converted by `formatMessages`.
 * @param {object} mergedOptions - Result of `mergeOptions`; `temperature`,
 *   `maxTokens` and `stop` are read.
 * @returns {object} `{ contents, generationConfig }`.
 */
function buildRequestBody(llm, contents, mergedOptions) {
    return {
        contents,
        generationConfig: {
            temperature: mergedOptions.temperature,
            maxOutputTokens: mergedOptions.maxTokens,
            topP: llm.defaultTopP ?? 1.0,
            // Optional fields are only included when they carry a value.
            ...(llm.defaultTopK && { topK: llm.defaultTopK }),
            ...(mergedOptions.stop && mergedOptions.stop.length > 0 && { stopSequences: mergedOptions.stop })
        }
    };
}
/**
 * Concatenate the `text` of every part of a candidate.
 *
 * @param {object} candidate - One entry of a response's `candidates` array.
 * @returns {string} The joined text, or `''` when the candidate has no
 *   `content.parts` array.
 */
function extractCandidateText(candidate) {
    const parts = candidate?.content?.parts;
    if (!Array.isArray(parts)) {
        return '';
    }
    return parts.map(part => part?.text ?? '').join('');
}
/**
 * Create an incremental scanner that pulls complete JSON objects out of a
 * text stream delivered in arbitrary pieces.
 *
 * The scanner only tracks `{` / `}` nesting and string/escape state, and
 * ignores every character outside an object. It therefore makes no assumption
 * about how the objects are framed (array brackets, commas, line prefixes,
 * newlines) or about where chunk boundaries fall.
 *
 * @returns {(text: string) => string[]} Function that accepts the next piece
 *   of text and returns the raw JSON source of every object completed by it.
 */
function createJsonObjectScanner() {
    let pending = ''; // source text of an object that is still open
    let depth = 0; // current `{` nesting level; 0 means "between objects"
    let inString = false;
    let escaped = false;
    return function scan(text) {
        const completed = [];
        // An object carried over from the previous call continues at index 0.
        let start = depth > 0 ? 0 : -1;
        for (let i = 0; i < text.length; i++) {
            const ch = text[i];
            if (depth === 0) {
                if (ch === '{') {
                    depth = 1;
                    start = i;
                }
                continue;
            }
            if (inString) {
                if (escaped) {
                    escaped = false;
                }
                else if (ch === '\\') {
                    escaped = true;
                }
                else if (ch === '"') {
                    inString = false;
                }
                continue;
            }
            if (ch === '"') {
                inString = true;
            }
            else if (ch === '{') {
                depth += 1;
            }
            else if (ch === '}') {
                depth -= 1;
                if (depth === 0) {
                    completed.push(pending + text.slice(start, i + 1));
                    pending = '';
                    start = -1;
                }
            }
        }
        if (depth > 0) {
            pending += text.slice(start);
        }
        return completed;
    };
}
/**
 * Chat LLM backed by the Gemini REST API, called through an axios client.
 */
class Gemini extends base_1.BaseChatLLM {
    /**
     * Get the identifying parameters of the LLM
     */
    get identifyingParams() {
        return {
            model: this.model,
            temperature: this.defaultTemperature,
            maxTokens: this.defaultMaxTokens,
            topP: this.defaultTopP,
            topK: this.defaultTopK
        };
    }
    /**
     * Get the type of LLM
     */
    get llmType() {
        return 'gemini';
    }
    /**
     * @param {object} config
     * @param {string} config.apiKey - Key appended to every request URL.
     * @param {string} [config.model='gemini-1.5-flash']
     * @param {number} [config.defaultTemperature=0.9]
     * @param {number} [config.defaultMaxTokens=2048]
     * @param {number} [config.defaultTopP]
     * @param {number} [config.defaultTopK]
     * @param {string} [config.baseURL] - Defaults to the public v1beta endpoint.
     */
    constructor(config) {
        super();
        this.apiKey = config.apiKey;
        this.model = config.model || 'gemini-1.5-flash';
        this.defaultTemperature = config.defaultTemperature ?? 0.9;
        this.defaultMaxTokens = config.defaultMaxTokens ?? 2048;
        this.defaultTopP = config.defaultTopP;
        this.defaultTopK = config.defaultTopK;
        const baseURL = config.baseURL || 'https://generativelanguage.googleapis.com/v1beta';
        this.client = axios_1.default.create({
            baseURL,
            headers: {
                'Content-Type': 'application/json',
            }
        });
    }
    /**
     * Format messages for Gemini API
     *
     * `human` maps to role `user`, `ai` to role `model`, and unknown types to
     * `user`. A `system` message is never emitted with its own role: it is
     * inserted at the front of the conversation, either as a new leading
     * `user` entry or as an extra leading part of the existing first `user`
     * entry.
     *
     * @param {Array<{type: string, content: string}>} messages
     * @returns {Array<{role: string, parts: Array<{text: string}>}>}
     */
    formatMessages(messages) {
        const formattedMessages = [];
        for (const message of messages) {
            let role;
            switch (message.type) {
                case 'human':
                    role = 'user';
                    break;
                case 'ai':
                    role = 'model';
                    break;
                case 'system':
                    // Gemini doesn't have system messages, prepend to first user message
                    if (formattedMessages.length === 0 || formattedMessages[0].role !== 'user') {
                        formattedMessages.unshift({
                            role: 'user',
                            parts: [{ text: message.content }]
                        });
                    }
                    else {
                        formattedMessages[0].parts.unshift({ text: message.content + '\n\n' });
                    }
                    continue;
                default:
                    role = 'user';
            }
            formattedMessages.push({
                role,
                parts: [{ text: message.content }]
            });
        }
        return formattedMessages;
    }
    /**
     * Merge call options with defaults
     *
     * Per-call options win over `getDefaultOptions()`; `temperature` and
     * `maxTokens` fall back to the instance defaults when not supplied.
     */
    mergeOptions(options) {
        return {
            ...this.getDefaultOptions(),
            ...options,
            temperature: options?.temperature ?? this.defaultTemperature,
            maxTokens: options?.maxTokens ?? this.defaultMaxTokens
        };
    }
    /**
     * Send the conversation to `:generateContent` and return the first
     * candidate.
     *
     * Calls `handleLLMStart` before the request, `handleLLMEnd` on success and
     * `handleLLMError` on failure.
     *
     * @param {Array<{type: string, content: string}>} messages
     * @param {object} [options] - Per-call overrides; `timeout` and `signal`
     *   are forwarded to axios.
     * @returns {Promise<{text: string, message: object, llmOutput: object}>}
     * @throws {Error} The request error, or an error describing a response
     *   that contains no candidate.
     */
    async generate(messages, options) {
        const mergedOptions = this.mergeOptions(options);
        const formattedMessages = this.formatMessages(messages);
        // Handle callbacks
        const prompts = messages.map(m => m.content);
        await this.handleLLMStart(prompts);
        try {
            const requestBody = buildRequestBody(this, formattedMessages, mergedOptions);
            const response = await this.client.post(`/models/${this.model}:generateContent?key=${this.apiKey}`, requestBody, {
                timeout: mergedOptions.timeout,
                signal: mergedOptions.signal
            });
            const candidate = response.data?.candidates?.[0];
            if (!candidate) {
                // Fail with a readable message instead of a TypeError on `undefined`.
                // NOTE(review): presumably this happens when the prompt itself was
                // blocked; blockReason is surfaced when the response carries one.
                const blockReason = response.data?.promptFeedback?.blockReason;
                throw new Error(blockReason
                    ? `Gemini returned no candidates (blockReason: ${blockReason})`
                    : 'Gemini returned no candidates');
            }
            const content = extractCandidateText(candidate);
            const result = {
                text: content,
                message: {
                    type: 'ai',
                    content: content,
                    additionalKwargs: {
                        safetyRatings: candidate.safetyRatings,
                        finishReason: candidate.finishReason
                    }
                },
                llmOutput: {
                    model: this.model,
                    promptFeedback: response.data.promptFeedback
                }
            };
            await this.handleLLMEnd(result);
            return result;
        }
        catch (error) {
            const err = error;
            await this.handleLLMError(err);
            // Provide more detailed error information
            if (error?.response) {
                // The body may be absent or not an object; guard every access so
                // that logging can never replace the error being rethrown.
                const errorData = error.response.data;
                console.error('\n🔴 Gemini API Error:', {
                    status: error.response.status,
                    message: errorData?.error?.message || errorData?.message,
                    code: errorData?.error?.code
                });
                if (error.response.status === 429) {
                    console.error('\n💡 Rate limit exceeded. Please check your API quota');
                }
                else if (error.response.status === 400) {
                    console.error('\n💡 Bad request. Check your API key and request format');
                }
            }
            throw err;
        }
    }
    /**
     * Send the conversation to `:streamGenerateContent` and yield one item per
     * response object that carries a candidate.
     *
     * The body is decoded incrementally and split into JSON objects by brace
     * matching, so objects that span several network chunks or several lines
     * are reassembled before being parsed.
     *
     * @param {Array<{type: string, content: string}>} messages
     * @param {object} [options] - Per-call overrides; `timeout` and `signal`
     *   are forwarded to axios.
     * @yields {{text: string, generationInfo: {finishReason: *, safetyRatings: *}}}
     * @throws {Error} The request error or a JSON parse error; `handleLLMError`
     *   is called first.
     */
    async *stream(messages, options) {
        // Gemini streaming requires different endpoint
        const mergedOptions = this.mergeOptions(options);
        const formattedMessages = this.formatMessages(messages);
        try {
            const requestBody = buildRequestBody(this, formattedMessages, mergedOptions);
            const response = await this.client.post(`/models/${this.model}:streamGenerateContent?key=${this.apiKey}`, requestBody, {
                responseType: 'stream',
                timeout: mergedOptions.timeout,
                signal: mergedOptions.signal
            });
            // A streaming decoder keeps multi-byte characters intact when they
            // straddle two chunks.
            const decoder = new TextDecoder('utf-8');
            const scan = createJsonObjectScanner();
            for await (const chunk of response.data) {
                const text = typeof chunk === 'string'
                    ? chunk
                    : decoder.decode(chunk, { stream: true });
                for (const json of scan(text)) {
                    const data = JSON.parse(json);
                    const candidate = data.candidates?.[0];
                    if (!candidate) {
                        continue;
                    }
                    yield {
                        text: extractCandidateText(candidate),
                        generationInfo: {
                            finishReason: candidate.finishReason,
                            safetyRatings: candidate.safetyRatings
                        }
                    };
                }
            }
        }
        catch (error) {
            const err = error;
            await this.handleLLMError(err);
            throw err;
        }
    }
}
exports.Gemini = Gemini;
/**
 * Create a Gemini instance from environment variables
 *
 * The key is taken from `config.apiKey` when it is set, otherwise from the
 * `GEMINI_API_KEY` environment variable.
 *
 * @param {object} [config] - Remaining constructor options.
 * @returns {Gemini}
 * @throws {Error} When no API key is available from either source.
 */
function createGeminiFromEnv(config) {
    const apiKey = config?.apiKey || process.env.GEMINI_API_KEY;
    if (!apiKey) {
        throw new Error('Gemini API key not found. Please set GEMINI_API_KEY environment variable.');
    }
    // Spread config first so an explicit `apiKey: undefined` cannot erase the key.
    return new Gemini({ ...config, apiKey });
}
//# sourceMappingURL=gemini.js.map