@webgal-tools/voice
Version:
WebGAL GPT-SoVITS语音合成应用
435 lines • 17 kB
JavaScript
import { EventSource } from 'eventsource';
import fs from 'fs';
import path from 'path';
import { logger } from '@webgal-tools/logger';
/**
 * Language option mapping (based on `dict_language` in the original code).
 * Keys are the labels callers pass as `prompt_language` / `text_language`;
 * values are the corresponding language codes.
 * Key order is significant: it is the order shown in validation error messages.
 */
const LANGUAGE_OPTIONS = {
    // Single-language modes
    '中文': 'all_zh',
    '英文': 'en',
    '日文': 'all_ja',
    '粤语': 'all_yue',
    '韩文': 'all_ko',
    // Mixed with English
    '中英混合': 'zh',
    '日英混合': 'ja',
    '粤英混合': 'yue',
    '韩英混合': 'ko',
    // Multilingual modes
    '多语种混合': 'auto',
    '多语种混合(粤语)': 'auto_yue',
};
/**
 * Text-splitting options.
 * Keys are the labels callers pass as `how_to_cut`; values are the
 * corresponding cut-method identifiers.
 * Key order is significant: it is the order shown in validation error messages.
 */
const TEXT_CUT_OPTIONS = {
    '不切': 'no_cut',
    '凑四句一切': 'cut1',
    '凑50字一切': 'cut2',
    '按中文句号。切': 'cut3',
    '按英文句号.切': 'cut4',
    '按标点符号切': 'cut5',
};
/**
 * Client for a GPT-SoVITS inference server driven through its Gradio queue
 * endpoints (`/queue/join`, `/queue/data`, `/upload`, `/file=`).
 *
 * One instance owns one randomly generated session hash, which is sent with
 * every queue request and used to subscribe to the matching SSE stream.
 */
class GPTSoVITSAPI {
    baseUrl;
    sessionHash;
    modelVersion;
    /**
     * @param {string} [baseUrl] Base URL of the server, without trailing slash.
     * @param {string} [modelVersion] Model version label; 'v3' and 'v4' enable
     *   the v3/v4-specific validation rules.
     */
    constructor(baseUrl = 'http://localhost:9872', modelVersion = 'v2') {
        this.baseUrl = baseUrl;
        this.sessionHash = this.generateSessionHash();
        this.modelVersion = {
            version: modelVersion,
            isV3V4: ['v3', 'v4'].includes(modelVersion)
        };
    }
    /**
     * Builds a short random base-36 identifier used as the session hash.
     * Not cryptographically secure; it only distinguishes sessions.
     * @returns {string}
     */
    generateSessionHash() {
        return Math.random().toString(36).substring(2, 15);
    }
    /**
     * Validates that the given generation options are within the allowed values.
     * Only options that are present are checked.
     * @param {object} config Generation options (see generateVoice).
     * @throws {Error} When an option is unknown, out of range, or not supported
     *   by the configured model version.
     */
    validateConfig(config) {
        // Object.hasOwn rather than `in`: inherited keys such as "toString" or
        // "constructor" must not be accepted as valid option labels.
        if (config.prompt_language && !Object.hasOwn(LANGUAGE_OPTIONS, config.prompt_language)) {
            throw new Error(`Invalid prompt_language. Must be one of: ${Object.keys(LANGUAGE_OPTIONS).join(', ')}`);
        }
        if (config.text_language && !Object.hasOwn(LANGUAGE_OPTIONS, config.text_language)) {
            throw new Error(`Invalid text_language. Must be one of: ${Object.keys(LANGUAGE_OPTIONS).join(', ')}`);
        }
        if (config.how_to_cut && !Object.hasOwn(TEXT_CUT_OPTIONS, config.how_to_cut)) {
            throw new Error(`Invalid how_to_cut. Must be one of: ${Object.keys(TEXT_CUT_OPTIONS).join(', ')}`);
        }
        // Numeric ranges. The negated "inside the range" form also rejects NaN
        // and non-numeric values, which a plain `< min || > max` test lets through.
        const assertRange = (name, min, max) => {
            const value = config[name];
            if (value !== undefined && !(value >= min && value <= max)) {
                throw new Error(`${name} must be between ${min} and ${max}`);
            }
        };
        assertRange('top_k', 1, 100);
        assertRange('top_p', 0, 1);
        assertRange('temperature', 0, 1);
        assertRange('speed', 0.6, 1.65);
        assertRange('pause_second', 0.1, 0.5);
        // Sampling steps must be one of the values allowed for this model version.
        if (config.sample_steps !== undefined) {
            const validSteps = this.getSampleStepsOptions();
            if (!validSteps.includes(config.sample_steps)) {
                throw new Error(`sample_steps must be one of: ${validSteps.join(', ')}`);
            }
        }
        // ref_text_free mode is rejected for v3/v4 models.
        if (this.modelVersion.isV3V4 && config.ref_text_free === true) {
            throw new Error('ref_text_free mode is not supported in v3/v4 models');
        }
        // if_sr is only accepted for the v3 model.
        if (config.if_sr === true && this.modelVersion.version !== 'v3') {
            throw new Error('if_sr (super resolution) is only available in v3 model');
        }
        // Multiple reference audios are rejected for v3/v4 models.
        if (config.inp_refs && config.inp_refs.length > 0 && this.modelVersion.isV3V4) {
            throw new Error('Multiple reference audios (inp_refs) are not supported in v3/v4 models');
        }
    }
    /**
     * Checks that a reference audio file exists and has a supported extension.
     * Files larger than 10 MB only produce a warning.
     * @param {string} audioPath Path of the audio file on the local disk.
     * @throws {Error} When the file is missing or its extension is unsupported.
     */
    validateAudioFile(audioPath) {
        if (!fs.existsSync(audioPath)) {
            throw new Error(`Audio file not found: ${audioPath}`);
        }
        const stats = fs.statSync(audioPath);
        const supportedExtensions = ['.wav', '.mp3', '.flac', '.m4a'];
        const ext = path.extname(audioPath).toLowerCase();
        if (!supportedExtensions.includes(ext)) {
            throw new Error(`Unsupported audio format. Supported formats: ${supportedExtensions.join(', ')}`);
        }
        // File size is used as a cheap stand-in for duration (3-10 s is recommended).
        const maxSize = 10 * 1024 * 1024; // 10MB
        if (stats.size > maxSize) {
            console.warn('Warning: Audio file is large. Reference audio should be 3-10 seconds long.');
        }
    }
    /**
     * Submits a job to the Gradio queue.
     * @param {Array} data Positional arguments for the remote function.
     * @param {number} fnIndex Index of the remote function to invoke.
     * @param {number} triggerId Trigger id sent along with the request.
     * @returns {Promise<string>} The event id assigned to the job.
     * @throws {Error} On a non-2xx response or when no event id is returned.
     */
    async sendToQueue(data, fnIndex, triggerId) {
        const response = await fetch(`${this.baseUrl}/queue/join?__theme=light`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Accept': '*/*',
            },
            body: JSON.stringify({
                data,
                event_data: null,
                fn_index: fnIndex,
                trigger_id: triggerId,
                session_hash: this.sessionHash,
            }),
        });
        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }
        const result = await response.json();
        // Without an event id the SSE listener would match every message that
        // lacks an event_id field, so fail here instead.
        if (!result || !result.event_id) {
            throw new Error('Queue join response did not contain an event_id');
        }
        return result.event_id;
    }
    /**
     * Waits on the session's SSE stream until the given job completes.
     * @param {string} eventId Event id returned by sendToQueue.
     * @param {number} [timeoutMs=60000] Maximum time to wait, in milliseconds.
     * @returns {Promise<object>} The `output` field of the completion message.
     * @throws {Error} On timeout, SSE error, unparsable message, failed job, or
     *   when the server closes the stream before the job completes.
     */
    async waitForCompletion(eventId, timeoutMs = 60000) {
        return new Promise((resolve, reject) => {
            const eventSource = new EventSource(`${this.baseUrl}/queue/data?session_hash=${this.sessionHash}`);
            let timer = null;
            // Single exit path: stop the timer, close the stream, settle the promise.
            const finish = (settle, value) => {
                clearTimeout(timer);
                eventSource.close();
                settle(value);
            };
            timer = setTimeout(() => {
                finish(reject, new Error('Timeout waiting for completion'));
            }, timeoutMs);
            eventSource.onmessage = (event) => {
                try {
                    const data = JSON.parse(event.data);
                    if (data.event_id === eventId) {
                        if (data.msg === 'process_completed') {
                            if (data.success && data.output) {
                                finish(resolve, data.output);
                            }
                            else {
                                finish(reject, new Error('Process completed but failed'));
                            }
                        }
                        else if (data.msg === 'process_generating') {
                            console.log('Generation in progress...');
                        }
                    }
                    else if (data.msg === 'close_stream') {
                        // The stream is closed here, so no completion message can
                        // arrive afterwards; reject instead of leaving the caller
                        // waiting forever.
                        finish(reject, new Error('SSE stream closed before the task completed'));
                    }
                }
                catch (error) {
                    finish(reject, new Error(`Failed to parse SSE data: ${error}`));
                }
            };
            eventSource.onerror = (error) => {
                // Prefer the event's message when it has one; a bare event object
                // would otherwise print as "[object Object]".
                finish(reject, new Error(`SSE error: ${error?.message ?? error}`));
            };
        });
    }
    /**
     * Uploads a local audio file to the server and builds the file descriptor
     * that is passed as a reference-audio argument.
     * @param {string} audioPath Path of the audio file on the local disk.
     * @returns {Promise<object>} File descriptor with path, url, orig_name,
     *   size, mime_type and meta fields.
     * @throws {Error} When validation or the upload fails.
     */
    async prepareAudioFile(audioPath) {
        this.validateAudioFile(audioPath);
        // The existing session hash doubles as the upload id.
        const uploadUrl = `${this.baseUrl}/upload?upload_id=${this.sessionHash}`;
        const formData = new FormData();
        const fileName = path.basename(audioPath);
        const mimeType = this.getMimeType(path.extname(fileName));
        const fileBuffer = fs.readFileSync(audioPath);
        const fileBlob = new Blob([fileBuffer], { type: mimeType });
        formData.append('files', fileBlob, fileName);
        const response = await fetch(uploadUrl, {
            method: 'POST',
            body: formData,
        });
        if (!response.ok) {
            throw new Error(`Upload failed: ${response.status} ${response.statusText}`);
        }
        const uploadedFiles = await response.json();
        if (!uploadedFiles || uploadedFiles.length === 0) {
            throw new Error('No file data returned from upload');
        }
        const uploadedFilePath = uploadedFiles[0];
        logger.info("音频文件上传成功", [uploadedFilePath]);
        // Backslashes in the returned path are normalised to forward slashes.
        const normalizedPath = uploadedFilePath.replace(/\\/g, '/');
        return {
            path: normalizedPath,
            url: `${this.baseUrl}/file=${normalizedPath}`,
            orig_name: fileName,
            size: fileBuffer.length,
            mime_type: mimeType,
            meta: {
                _type: 'gradio.FileData'
            }
        };
    }
    /**
     * Uploads several reference audio files one after another.
     * @param {string[]} audioPaths Local paths of the audio files.
     * @returns {Promise<object[]>} One file descriptor per input path, in order.
     * @throws {Error} For v3/v4 models, or when any upload fails.
     */
    async prepareMultipleAudioFiles(audioPaths) {
        if (this.modelVersion.isV3V4) {
            throw new Error('Multiple reference audios are not supported in v3/v4 models');
        }
        const audioFiles = [];
        for (const audioPath of audioPaths) {
            const audioFile = await this.prepareAudioFile(audioPath);
            audioFiles.push(audioFile);
        }
        return audioFiles;
    }
    /**
     * Maps a file extension to its MIME type.
     * @param {string} extension Extension including the leading dot; case-insensitive.
     * @returns {string} The MIME type, or 'audio/wav' for unknown extensions.
     */
    getMimeType(extension) {
        const mimeTypes = {
            '.wav': 'audio/wav',
            '.mp3': 'audio/mpeg',
            '.flac': 'audio/flac',
            '.m4a': 'audio/mp4'
        };
        return mimeTypes[extension.toLowerCase()] || 'audio/wav';
    }
    /**
     * Asks the server to load a GPT model.
     * @param {string} modelName Model file name; must end with '.ckpt'.
     * @returns {Promise<boolean>} true on success; false on any failure
     *   (the error is logged, not thrown).
     */
    async setGptModel(modelName) {
        try {
            if (!modelName.endsWith('.ckpt')) {
                throw new Error('GPT model name must end with .ckpt');
            }
            const eventId = await this.sendToQueue([modelName], 3, Date.now());
            await this.waitForCompletion(eventId);
            console.log(`Successfully loaded GPT model: ${modelName}`);
            return true;
        }
        catch (error) {
            logger.error(`Failed to set GPT model: ${error}`);
            return false;
        }
    }
    /**
     * Asks the server to load a SoVITS model.
     * @param {string} modelName Model file name; must end with '.pth'.
     * @param {string} [promptLanguage] A key of LANGUAGE_OPTIONS.
     * @param {string} [textLanguage] A key of LANGUAGE_OPTIONS.
     * @returns {Promise<boolean>} true on success; false on any failure
     *   (the error is logged, not thrown).
     */
    async setSovitsModel(modelName, promptLanguage = '日文', textLanguage = '日文') {
        try {
            if (!modelName.endsWith('.pth')) {
                throw new Error('SoVITS model name must end with .pth');
            }
            // Own-key check so inherited names like "toString" are rejected.
            if (!Object.hasOwn(LANGUAGE_OPTIONS, promptLanguage)) {
                throw new Error(`Invalid promptLanguage. Must be one of: ${Object.keys(LANGUAGE_OPTIONS).join(', ')}`);
            }
            if (!Object.hasOwn(LANGUAGE_OPTIONS, textLanguage)) {
                throw new Error(`Invalid textLanguage. Must be one of: ${Object.keys(LANGUAGE_OPTIONS).join(', ')}`);
            }
            const eventId = await this.sendToQueue([modelName, promptLanguage, textLanguage], 2, Date.now());
            await this.waitForCompletion(eventId);
            console.log(`Successfully loaded SoVITS model: ${modelName}`);
            return true;
        }
        catch (error) {
            logger.error(`Failed to set SoVITS model: ${error}`);
            return false;
        }
    }
    /**
     * Synthesises speech for `targetText` using a reference voice.
     * @param {string} refVoicePath Local path of the main reference audio.
     * @param {string} refVoiceText Transcript of the reference audio.
     * @param {string} targetText Text to synthesise; must not be blank.
     * @param {object} [config] Overrides for the default generation options.
     *   Entries whose value is `undefined` are ignored.
     * @returns {Promise<string>} Server-side path of the generated audio.
     * @throws {Error} On invalid options, upload/queue failure, or when no audio
     *   is returned. Errors are logged and re-thrown.
     */
    async generateVoice(refVoicePath, refVoiceText, targetText, config = {}) {
        try {
            this.validateConfig(config);
            if (!targetText || targetText.trim().length === 0) {
                throw new Error('Target text cannot be empty');
            }
            const audioFileData = await this.prepareAudioFile(refVoicePath);
            // Drop explicit `undefined` overrides so they cannot erase a default
            // (an undefined array entry would be serialised as null).
            const overrides = Object.fromEntries(Object.entries(config).filter(([, value]) => value !== undefined));
            // Defaults follow the original form's default values.
            const defaultConfig = {
                prompt_language: '日文',
                text_language: '日文',
                how_to_cut: '凑四句一切',
                top_k: 15,
                top_p: 1,
                temperature: 1,
                ref_text_free: false,
                speed: 1,
                if_freeze: false,
                inp_refs: null,
                sample_steps: this.modelVersion.version === 'v3' ? 32 : 8,
                if_sr: false,
                pause_second: 0.3,
                ...overrides
            };
            // Extra reference audios, if any. validateConfig has already rejected
            // a non-empty inp_refs for v3/v4 models, so no version check is needed.
            let inpRefsData = null;
            if (defaultConfig.inp_refs && defaultConfig.inp_refs.length > 0) {
                inpRefsData = await this.prepareMultipleAudioFiles(defaultConfig.inp_refs);
            }
            // Positional arguments, in the order of the original function's parameters.
            const requestData = [
                audioFileData, // inp_ref
                refVoiceText, // prompt_text
                defaultConfig.prompt_language, // prompt_language
                targetText, // text
                defaultConfig.text_language, // text_language
                defaultConfig.how_to_cut, // how_to_cut
                defaultConfig.top_k, // top_k
                defaultConfig.top_p, // top_p
                defaultConfig.temperature, // temperature
                defaultConfig.ref_text_free, // ref_text_free
                defaultConfig.speed, // speed
                defaultConfig.if_freeze, // if_freeze
                inpRefsData, // inp_refs
                defaultConfig.sample_steps, // sample_steps
                defaultConfig.if_sr, // if_sr_Checkbox
                defaultConfig.pause_second // pause_second_slider
            ];
            console.log('Starting voice generation...');
            console.log('Config:', defaultConfig);
            const eventId = await this.sendToQueue(requestData, 1, Date.now());
            const result = await this.waitForCompletion(eventId);
            const outputAudio = result.data && result.data.length > 0 ? result.data[0] : null;
            // A null first entry is treated the same as an empty result.
            if (!outputAudio) {
                throw new Error('No audio data returned from generation');
            }
            console.log(`Voice generation completed: ${outputAudio.path}`);
            return outputAudio.path;
        }
        catch (error) {
            logger.error(`Voice generation failed: ${error}`);
            throw error;
        }
    }
    /**
     * Lists the accepted language option labels.
     * @returns {string[]}
     */
    static getLanguageOptions() {
        return Object.keys(LANGUAGE_OPTIONS);
    }
    /**
     * Lists the accepted text-splitting option labels.
     * @returns {string[]}
     */
    static getTextCutOptions() {
        return Object.keys(TEXT_CUT_OPTIONS);
    }
    /**
     * Lists the sampling step counts allowed for the configured model version.
     * @returns {number[]}
     */
    getSampleStepsOptions() {
        return this.modelVersion.isV3V4 ? [4, 8, 16, 32, 64, 128] : [4, 8, 16, 32];
    }
    /**
     * Builds the download URL for a server-side audio path.
     * @param {string} audioPath Server-side path; backslashes are converted to '/'.
     * @returns {string}
     */
    getAudioDownloadUrl(audioPath) {
        return `${this.baseUrl}/file=${audioPath.replace(/\\/g, '/')}`;
    }
    /**
     * Downloads a generated audio file and writes it to disk.
     * @param {string} audioPath Server-side path of the audio.
     * @param {string} outputPath Local path to write to.
     * @throws {Error} On a non-2xx response.
     */
    async downloadAudio(audioPath, outputPath) {
        const url = this.getAudioDownloadUrl(audioPath);
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(`Failed to download audio: ${response.status}`);
        }
        const arrayBuffer = await response.arrayBuffer();
        const buffer = Buffer.from(arrayBuffer);
        fs.writeFileSync(outputPath, buffer);
        console.log(`Audio saved to: ${outputPath}`);
    }
}
// Function-style API: the exported helpers share one lazily created client.
let apiInstance = null;
/**
 * Returns the shared client, creating it on first use.
 * `modelVersion` is only used when the instance is first created; later calls
 * return the existing instance whatever version is passed.
 */
function getAPIInstance(modelVersion = 'v2') {
    if (apiInstance) {
        return apiInstance;
    }
    apiInstance = new GPTSoVITSAPI('http://localhost:9872', modelVersion);
    return apiInstance;
}
/**
 * Loads a GPT model through the shared client.
 * Resolves with whatever the client's setGptModel resolves with.
 */
export async function set_gpt_model(modelName) {
    const api = getAPIInstance();
    const loaded = await api.setGptModel(modelName);
    return loaded;
}
/**
 * Loads a SoVITS model through the shared client.
 * Both language arguments default to '日文'.
 * Resolves with whatever the client's setSovitsModel resolves with.
 */
export async function set_sovits_model(modelName, promptLanguage = '日文', textLanguage = '日文') {
    const api = getAPIInstance();
    const loaded = await api.setSovitsModel(modelName, promptLanguage, textLanguage);
    return loaded;
}
/**
 * Generates speech through the shared client.
 * Resolves with whatever the client's generateVoice resolves with and
 * rejects with whatever it rejects with.
 */
export async function generate_voice(refVoicePath, refVoiceText, targetText, config = {}) {
    const api = getAPIInstance();
    const outputPath = await api.generateVoice(refVoicePath, refVoiceText, targetText, config);
    return outputPath;
}
// Export the client class and the option constants
export { GPTSoVITSAPI, LANGUAGE_OPTIONS, TEXT_CUT_OPTIONS };
//# sourceMappingURL=request.js.map