// gpt-sovits-sdk
// Version:
// Node.js SDK for GPT-SoVITS API
// 390 lines (389 loc) • 15.4 kB
// JavaScript
;
/**
* GPT-SoVITS SDK 客户端实现
*/
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.GPTSoVITSClient = void 0;
const types_1 = require("./types");
const form_data_1 = __importDefault(require("form-data"));
const promises_1 = require("fs/promises");
/**
 * Default client configuration.
 * The GPTSoVITSClient constructor spreads caller-supplied options over these
 * values, so any key given by the caller replaces the default below.
 */
const DEFAULT_OPTIONS = {
// Base URL that request endpoints are resolved against.
baseUrl: 'http://127.0.0.1:9880',
// Delay handed to setTimeout before a request is aborted (milliseconds).
timeout: 30000,
// When true, the client's log() method prints to the console.
debug: false,
// Default number of retry attempts for a failed request.
retries: 0
};
/**
 * HTTP client for the GPT-SoVITS API.
 *
 * Every request gets its own AbortController, so a timeout only cancels the
 * request it belongs to and `abort()` cancels all requests still in flight.
 */
class GPTSoVITSClient {
    /**
     * Create a client instance.
     * @param {object} [options] Client options; merged over DEFAULT_OPTIONS.
     */
    constructor(options = {}) {
        this.options = { ...DEFAULT_OPTIONS, ...options };
        // Controllers of all requests that are currently in flight.
        this.activeControllers = new Set();
        // Controller of the most recently started request (kept for backward compatibility).
        this.controller = new AbortController();
        this.log('GPT-SoVITS客户端已初始化', this.options);
    }
    /**
     * Print a debug message when `options.debug` is enabled.
     * @param {string} message Message text.
     * @param {...*} data Extra values forwarded to console.log.
     */
    log(message, ...data) {
        if (this.options.debug) {
            console.log(`[GPT-SoVITS] ${message}`, ...data);
        }
    }
    /**
     * Resolve an endpoint against the base URL and append query parameters.
     * Parameters whose value is null or undefined are skipped; array values
     * are appended once per element under the same key.
     * @param {string} endpoint API endpoint.
     * @param {object} [params] Query parameters.
     * @returns {string} The complete URL.
     */
    buildUrl(endpoint, params) {
        const url = new URL(endpoint, this.options.baseUrl);
        if (params) {
            for (const [key, value] of Object.entries(params)) {
                if (value === undefined || value === null) {
                    continue;
                }
                const items = Array.isArray(value) ? value : [value];
                for (const item of items) {
                    url.searchParams.append(key, String(item));
                }
            }
        }
        return url.toString();
    }
    /**
     * Derive an error message from a non-OK response.
     * JSON bodies contribute their `message` field (or their serialized form);
     * other bodies are used as text. A body that cannot be read, or that is
     * empty or JSON `null`, falls back to the HTTP status line.
     * @param {Response} response The failed response.
     * @returns {Promise<*>} Value to use as the error message.
     */
    async readErrorMessage(response) {
        const fallback = `HTTP错误: ${response.status} ${response.statusText}`;
        let errorData;
        try {
            if (response.headers.get('content-type')?.includes('application/json')) {
                errorData = await response.json();
            }
            else {
                errorData = await response.text();
            }
        }
        catch (e) {
            return fallback;
        }
        // typeof null === 'object', so null must be excluded before reading `.message`.
        if (errorData !== null && typeof errorData === 'object') {
            return errorData.message || JSON.stringify(errorData);
        }
        if (errorData === null || errorData === '') {
            return fallback;
        }
        return errorData;
    }
    /**
     * Decode a successful response according to the requested response type.
     * @param {Response} response The successful response.
     * @param {string} responseType 'json', 'arraybuffer', or anything else for text.
     * @returns {Promise<*>} The decoded body.
     */
    async readResponseBody(response, responseType) {
        if (responseType === 'arraybuffer') {
            return await response.arrayBuffer();
        }
        if (responseType === 'json') {
            if (response.headers.get('content-type')?.includes('application/json')) {
                return await response.json();
            }
            this.log('警告: 响应不是JSON格式,但请求了JSON类型');
        }
        return await response.text();
    }
    /**
     * Send an HTTP request.
     *
     * One timeout (options.timeout) covers the whole call, including retries
     * and back-off delays. Failed attempts, including non-OK HTTP responses,
     * are retried up to `retries` times with exponential back-off; aborted
     * requests are never retried.
     * @param {string} endpoint API endpoint.
     * @param {object} [options] method, params, body, headers, responseType, retries.
     * @returns {Promise<*>} The decoded response body.
     * @throws GPTSoVITSError on timeout, cancellation, HTTP error or network failure.
     */
    async request(endpoint, options = {}) {
        const { method = 'GET', params, body, headers = {}, responseType = 'json', retries = this.options.retries } = options;
        const url = this.buildUrl(endpoint, params);
        // A dedicated controller keeps concurrent requests from aborting each other.
        const controller = new AbortController();
        this.controller = controller;
        this.activeControllers.add(controller);
        let timedOut = false;
        const timeoutId = setTimeout(() => {
            timedOut = true;
            controller.abort();
        }, this.options.timeout);
        try {
            const requestHeaders = {
                ...this.options.headers,
                ...headers
            };
            let requestBody = undefined;
            if (body) {
                if (body instanceof form_data_1.default) {
                    // Content-Type is deliberately not set here for form data.
                    requestBody = body;
                }
                else if (typeof body === 'object') {
                    requestBody = JSON.stringify(body);
                    requestHeaders['Content-Type'] = 'application/json';
                }
                else {
                    requestBody = body;
                }
            }
            this.log(`发送${method}请求到: ${url}`, { headers: requestHeaders });
            let currentRetry = 0;
            while (true) {
                try {
                    const response = await fetch(url, {
                        method,
                        headers: requestHeaders,
                        body: requestBody,
                        signal: controller.signal
                    });
                    if (!response.ok) {
                        const message = await this.readErrorMessage(response);
                        throw new types_1.GPTSoVITSError(message, {
                            code: response.status,
                            url,
                            method
                        });
                    }
                    return await this.readResponseBody(response, responseType);
                }
                catch (error) {
                    // Aborts are final: report whether the timer or abort() caused them.
                    if (error?.name === 'AbortError') {
                        throw new types_1.GPTSoVITSError(timedOut ? '请求超时' : '请求已取消', {
                            url,
                            method
                        });
                    }
                    if (currentRetry < retries) {
                        currentRetry++;
                        this.log(`请求失败,正在重试 (${currentRetry}/${retries})`, error);
                        // Exponential back-off: 1s, 2s, 4s, ...
                        await new Promise(resolve => setTimeout(resolve, 1000 * Math.pow(2, currentRetry - 1)));
                        continue;
                    }
                    // Retries exhausted: rethrow SDK errors, wrap anything else.
                    if (error instanceof types_1.GPTSoVITSError) {
                        throw error;
                    }
                    throw new types_1.GPTSoVITSError(error instanceof Error ? error.message : String(error), {
                        cause: error,
                        url,
                        method
                    });
                }
            }
        }
        finally {
            clearTimeout(timeoutId);
            this.activeControllers.delete(controller);
        }
    }
    /**
     * Fetch the API root.
     * @returns {Promise<*>} Response of GET `/`.
     */
    async getRoot() {
        return this.request('/');
    }
    /**
     * Fetch the API health status.
     * @returns {Promise<*>} Response of GET `/health`.
     */
    async getHealth() {
        return this.request('/health');
    }
    /**
     * Fetch the list of available models.
     * @returns {Promise<*>} Response of GET `/api/auxiliary/models`.
     */
    async getModels() {
        return this.request('/api/auxiliary/models');
    }
    /**
     * Fetch the list of reference audios.
     * @param {string} [subdir] Sub-directory name; sent as a query parameter when truthy.
     * @returns {Promise<*>} Response of GET `/api/auxiliary/reference_audios`.
     */
    async getReferenceAudios(subdir) {
        return this.request('/api/auxiliary/reference_audios', {
            params: subdir ? { subdir } : undefined
        });
    }
    /**
     * Fetch the list of emotion audios.
     * @param {object} [options] Query options; sent as query parameters.
     * @returns {Promise<*>} Response of GET `/api/auxiliary/emotion_audios`.
     */
    async getEmotionAudios(options = {}) {
        return this.request('/api/auxiliary/emotion_audios', {
            params: options
        });
    }
    /**
     * Fetch the list of emotion reference audios.
     * @returns {Promise<*>} Response of GET `/api/auxiliary/emotion_reference_audios`.
     */
    async getEmotionReferenceAudios() {
        return this.request('/api/auxiliary/emotion_reference_audios');
    }
    /**
     * Return a shallow copy of the options (an empty object for falsy input).
     * @param {object} [options] Original options.
     * @returns {object} The copied options.
     */
    normalizeOptions(options) {
        if (!options)
            return {};
        return { ...options };
    }
    /**
     * Map TTS options, given in camelCase or snake_case, to the snake_case
     * request payload. The camelCase spelling wins when both are present.
     * `??` is used rather than `||` so that explicit falsy values such as
     * `parallelInfer: false`, `topK: 0` or `promptText: ''` are sent instead
     * of being silently dropped.
     * @param {object} options TTS options.
     * @returns {object} Payload with snake_case keys.
     */
    buildTtsPayload(options) {
        const opts = this.normalizeOptions(options);
        return {
            text: opts.text,
            text_lang: opts.textLang ?? opts.text_lang,
            ref_audio_path: opts.refAudioPath ?? opts.ref_audio_path,
            prompt_lang: opts.promptLang ?? opts.prompt_lang,
            prompt_text: opts.promptText ?? opts.prompt_text,
            aux_ref_audio_paths: opts.auxRefAudioPaths ?? opts.aux_ref_audio_paths,
            gpt_model: opts.gptModel ?? opts.gpt_model,
            sovits_model: opts.sovitsModel ?? opts.sovits_model,
            top_k: opts.topK ?? opts.top_k,
            top_p: opts.topP ?? opts.top_p,
            temperature: opts.temperature,
            text_split_method: opts.textSplitMethod ?? opts.text_split_method,
            batch_size: opts.batchSize ?? opts.batch_size,
            batch_threshold: opts.batchThreshold ?? opts.batch_threshold,
            split_bucket: opts.splitBucket ?? opts.split_bucket,
            speed_factor: opts.speedFactor ?? opts.speed_factor,
            fragment_interval: opts.fragmentInterval ?? opts.fragment_interval,
            seed: opts.seed,
            parallel_infer: opts.parallelInfer ?? opts.parallel_infer,
            repetition_penalty: opts.repetitionPenalty ?? opts.repetition_penalty,
            media_type: opts.mediaType ?? opts.media_type,
            streaming_mode: opts.streamingMode ?? opts.streaming_mode
        };
    }
    /**
     * Write audio data to `outputPath` when one is given, otherwise hand the
     * data back unchanged.
     * @param {ArrayBuffer} audioData Audio bytes.
     * @param {string} [outputPath] Destination file path.
     * @returns {Promise<ArrayBuffer|string>} The data, or the path that was written.
     */
    async deliverAudio(audioData, outputPath) {
        if (outputPath) {
            await (0, promises_1.writeFile)(outputPath, Buffer.from(audioData));
            return outputPath;
        }
        return audioData;
    }
    /**
     * Text to speech via POST `/api/core/tts`.
     * @param {object} options TTS options (camelCase or snake_case keys).
     * @returns {Promise<*>} The decoded TTS response.
     */
    async textToSpeech(options) {
        const payload = this.buildTtsPayload(options);
        return this.request('/api/core/tts', {
            method: 'POST',
            body: payload,
            headers: {
                'Content-Type': 'application/json'
            }
        });
    }
    /**
     * Text to speech via POST `/api/core/tts_direct`, returning the audio bytes.
     * `streaming_mode` is always sent as false for this endpoint.
     * @param {object} options TTS options (camelCase or snake_case keys).
     * @param {string} [outputPath] When given, the audio is written to this file.
     * @returns {Promise<ArrayBuffer|string>} Audio data, or the written file path.
     */
    async textToSpeechDirect(options, outputPath) {
        const payload = {
            ...this.buildTtsPayload(options),
            streaming_mode: false
        };
        const audioData = await this.request('/api/core/tts_direct', {
            method: 'POST',
            body: payload,
            headers: {
                'Content-Type': 'application/json'
            },
            responseType: 'arraybuffer'
        });
        return this.deliverAudio(audioData, outputPath);
    }
    /**
     * Select the GPT model via POST `/api/core/set_gpt_model`.
     * @param {string} modelPath Model path.
     * @returns {Promise<*>} The decoded response.
     */
    async setGptModel(modelPath) {
        return this.request('/api/core/set_gpt_model', {
            method: 'POST',
            body: { model_path: modelPath },
            headers: {
                'Content-Type': 'application/json'
            }
        });
    }
    /**
     * Select the SoVITS model via POST `/api/core/set_sovits_model`.
     * @param {string} modelPath Model path.
     * @returns {Promise<*>} The decoded response.
     */
    async setSovitsModel(modelPath) {
        return this.request('/api/core/set_sovits_model', {
            method: 'POST',
            body: { model_path: modelPath },
            headers: {
                'Content-Type': 'application/json'
            }
        });
    }
    /**
     * Download a generated audio file by name.
     * @param {string} audioName Audio file name (URI-encoded into the path).
     * @param {string} [outputPath] When given, the audio is written to this file.
     * @returns {Promise<ArrayBuffer|string>} Audio data, or the written file path.
     */
    async getAudio(audioName, outputPath) {
        const audioData = await this.request(`/api/core/audio/${encodeURIComponent(audioName)}`, {
            responseType: 'arraybuffer'
        });
        return this.deliverAudio(audioData, outputPath);
    }
    /**
     * Cancel every request that is currently in flight.
     */
    abort() {
        for (const controller of this.activeControllers) {
            controller.abort();
        }
        this.activeControllers.clear();
        // Also abort the legacy handle in case it was replaced from outside.
        this.controller.abort();
        this.controller = new AbortController();
        this.log('已取消所有正在进行的请求');
    }
}
exports.GPTSoVITSClient = GPTSoVITSClient;