@neuroequalityorg/knightcode
Knightcode CLI - Your local AI coding assistant using Ollama, LM Studio, and more
/**
* Ollama Client
*
* Handles interaction with Ollama API for text completion
* and code assistance features.
*/
import { logger } from '../utils/logger.js';
import { createUserError } from '../errors/formatter.js';
import { ErrorCategory } from '../errors/types.js';
import { withTimeout, withRetry } from '../utils/async.js';
import { AIProvider, Message, CompletionOptions, CompletionResponse, StreamEvent } from './provider.js';
// Request payload for Ollama's /api/generate endpoint, matching the
// request objects built in complete() and completeStream() below
interface CompletionRequest {
model: string;
prompt: string;
stream: boolean;
system?: string;
options?: {
temperature?: number;
num_predict?: number;
top_p?: number;
top_k?: number;
stop?: string[];
};
}
interface ApiError {
// Ollama reports errors as { "error": "message" }; tolerate a nested
// { error: { message } } shape as well
error?: string | { message?: string };
}
// Default API configuration
const DEFAULT_CONFIG = {
apiBaseUrl: 'http://localhost:11434',
timeout: 60000, // 60 seconds
retryOptions: {
maxRetries: 3,
initialDelayMs: 1000,
maxDelayMs: 10000
},
defaultModel: 'devstral:24b',
defaultMaxTokens: 4096,
defaultTemperature: 0.7
};
/**
* Ollama AI client for interacting with Ollama API
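*
* @example
* // Minimal usage sketch (assumes an Ollama server on localhost with a pulled model):
* const client = new OllamaClient();
* if (await client.testConnection()) {
*   const res = await client.complete({ messages: [{ role: 'user', content: 'Hello' }] });
*   console.log(res.content[0].text);
* }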
*/
export class OllamaClient implements AIProvider {
private config: typeof DEFAULT_CONFIG;
/**
* Create a new Ollama client
*/
constructor(config: Partial<typeof DEFAULT_CONFIG> = {}) {
this.config = { ...DEFAULT_CONFIG, ...config };
logger.debug('Ollama client created with config', {
apiBaseUrl: this.config.apiBaseUrl,
defaultModel: this.config.defaultModel
});
}
/**
* Format API request headers
*/
private getHeaders(): Record<string, string> {
return {
'Content-Type': 'application/json'
};
}
/**
* Send a completion request to Ollama
*/
async complete(options: CompletionOptions): Promise<CompletionResponse> {
logger.debug('Sending completion request', { model: options.model || this.config.defaultModel });
// Format the request
const messages: Message[] = options.messages;
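// /api/generate takes a single prompt string, so chat roles are flattened here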
const request: CompletionRequest = {
model: options.model || this.config.defaultModel,
prompt: messages.map(m => m.content).join('\n'),
stream: false,
options: {
temperature: options.temperature ?? this.config.defaultTemperature,
num_predict: options.maxTokens || this.config.defaultMaxTokens,
top_p: options.topP,
top_k: options.topK,
stop: options.stopSequences
}
};
// Make the API request with timeout and retry
try {
const sendRequestWithPath = async (path: string, requestOptions: RequestInit) => {
return this.sendRequest(path, requestOptions);
};
const timeoutFn = withTimeout(sendRequestWithPath, this.config.timeout);
const retryFn = withRetry(timeoutFn, {
maxRetries: this.config.retryOptions.maxRetries,
initialDelayMs: this.config.retryOptions.initialDelayMs,
maxDelayMs: this.config.retryOptions.maxDelayMs
});
const response = await retryFn('/api/generate', {
method: 'POST',
headers: this.getHeaders(),
body: JSON.stringify(request)
});
// Convert Ollama response to our CompletionResponse format
return {
id: Date.now().toString(),
model: response.model,
usage: {
input_tokens: response.prompt_eval_count || 0,
output_tokens: response.eval_count || 0
},
content: [{
type: 'text',
text: response.response
}],
// Recent Ollama versions report why generation ended in `done_reason`
stop_reason: response.done_reason
};
} catch (error: unknown) {
// Preserve already-formatted errors (e.g. from handleErrorResponse)
if (error instanceof Error) {
throw error;
}
let errorMessage = 'Unknown error';
if (this.isApiErrorResponse(error)) {
errorMessage = error.error?.message || errorMessage;
}
throw new Error(errorMessage);
}
}
/**
* Send a streaming completion request to Ollama
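* (Currently emulated: the full completion is fetched, then emitted as start/stop events.)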
*/
async *completeStream(options: CompletionOptions): AsyncGenerator<StreamEvent> {
logger.debug('Sending streaming completion request', { model: options.model || this.config.defaultModel });
// Format the request
const messages: Message[] = options.messages;
const request: CompletionRequest = {
model: options.model || this.config.defaultModel,
prompt: messages.map(m => m.content).join('\n'),
stream: true,
options: {
temperature: options.temperature ?? this.config.defaultTemperature,
num_predict: options.maxTokens || this.config.defaultMaxTokens,
top_p: options.topP,
top_k: options.topK,
stop: options.stopSequences
}
};
// Placeholder: buffers the full completion via complete() and emits
// start/stop events instead of incremental deltas
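// A real implementation could POST `request` (note stream: true) to
// /api/generate and yield one delta per NDJSON chunk, e.g. by bridging
// sendStreamRequest's callback into this generator with an async queue
// (hypothetical AsyncQueue helper, not part of this codebase):
//
//   const queue = new AsyncQueue<StreamEvent>();
//   this.sendStreamRequest('/api/generate', {
//     method: 'POST',
//     headers: this.getHeaders(),
//     body: JSON.stringify(request)
//   }, event => queue.push(event)).then(() => queue.close());
//   for await (const event of queue) yield event;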
try {
const response = await this.complete(options);
yield {
type: 'message_start',
message: {
id: response.id,
model: response.model,
content: response.content,
stop_reason: response.stop_reason
}
};
yield {
type: 'message_stop',
message: {
id: response.id,
model: response.model,
content: response.content,
stop_reason: response.stop_reason
}
};
} catch (error) {
logger.error('Ollama streaming failed:', error);
throw error;
}
}
/**
* Type guard for API error response
*/
private isApiErrorResponse(error: unknown): error is { error?: { message?: string } } {
return (
typeof error === 'object' &&
error !== null &&
'error' in error &&
typeof (error as any).error === 'object' &&
(error as any).error !== null
);
}
/**
* Test the connection to the Ollama API
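* Lists models first; if none are found, falls back to a raw /api/tags health check.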
*/
async testConnection(): Promise<boolean> {
logger.debug('Testing connection to Ollama API');
try {
// First try to get available models (lighter request)
const models = await this.getModels();
if (models.length > 0) {
logger.debug('Connection test successful - found models:', models);
return true;
}
// If no models found, try a simple health check
const response = await fetch(`${this.config.apiBaseUrl}/api/tags`, {
method: 'GET',
headers: this.getHeaders()
});
if (response.ok) {
logger.debug('Connection test successful - API responded');
return true;
}
logger.debug('Connection test failed - API responded with status:', response.status);
return false;
} catch (error: unknown) {
logger.debug('Connection test failed with error:', error);
return false;
}
}
/**
* Handle error responses from the API
*/
private async handleErrorResponse(response: Response): Promise<never> {
let errorMessage: string;
let errorDetails: any;
try {
const errorResponse = await response.json() as ApiError;
// Ollama returns the error as a plain string; tolerate a nested shape too
errorMessage = (typeof errorResponse.error === 'string'
? errorResponse.error
: errorResponse.error?.message) || 'Unknown error';
errorDetails = errorResponse;
} catch {
errorMessage = `HTTP error ${response.status}`;
errorDetails = { status: response.status };
}
throw createUserError(`Ollama API error: ${errorMessage}`, {
category: ErrorCategory.AI_SERVICE,
details: errorDetails
});
}
/**
* Send a request to the Ollama API
*/
private async sendRequest(path: string, options: RequestInit): Promise<any> {
const url = `${this.config.apiBaseUrl}${path}`;
const response = await fetch(url, options);
if (!response.ok) {
await this.handleErrorResponse(response);
}
return response.json();
}
/**
* Send a streaming request to the Ollama API
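* Parses the newline-delimited JSON chunks Ollama emits and forwards
* one StreamEvent per content chunk to the supplied callback.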
*/
private async sendStreamRequest(
path: string,
options: RequestInit,
onEvent: (event: StreamEvent) => void
): Promise<void> {
const url = `${this.config.apiBaseUrl}${path}`;
const response = await fetch(url, options);
if (!response.ok) {
await this.handleErrorResponse(response);
}
const reader = response.body?.getReader();
if (!reader) {
throw new Error('Failed to get response reader');
}
const decoder = new TextDecoder();
let buffer = '';
try {
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
buffer += decoder.decode(value, { stream: true });
// Process complete JSON objects
let newlineIndex;
while ((newlineIndex = buffer.indexOf('\n')) !== -1) {
const line = buffer.slice(0, newlineIndex);
buffer = buffer.slice(newlineIndex + 1);
if (line.trim()) {
try {
const data = JSON.parse(line);
// Convert Ollama stream event to our StreamEvent format; the final
// chunk carries stats and an empty `response`, so skip non-content chunks
if (typeof data.response === 'string' && data.response.length > 0) {
const event: StreamEvent = {
type: 'content_block_delta',
delta: {
type: 'text',
text: data.response
}
};
onEvent(event);
}
onEvent(event);
} catch (error) {
logger.error('Failed to parse stream event', error);
}
}
}
}
} finally {
reader.releaseLock();
}
}
/**
* Get available models from Ollama
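* Returns an empty array if the request fails.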
*/
async getModels(): Promise<string[]> {
try {
const response = await this.sendRequest('/api/tags', { method: 'GET' });
return response.models?.map((model: { name: string }) => model.name) || [];
} catch (error) {
logger.error('Failed to get Ollama models:', error);
return [];
}
}
/**
* Set the model to use
*/
setModel(model: string): void {
this.config.defaultModel = model;
logger.debug('Ollama model set to:', model);
}
/**
* Get the current model
*/
getModel(): string {
return this.config.defaultModel;
}
/**
* Get configuration information
*/
getConfig(): typeof DEFAULT_CONFIG {
return { ...this.config };
}
/**
* Update configuration
*/
updateConfig(newConfig: Partial<typeof DEFAULT_CONFIG>): void {
this.config = { ...this.config, ...newConfig };
logger.debug('Ollama client configuration updated:', this.config);
}
/**
* Get provider name
*/
getProviderName(): string {
return 'ollama';
}
}