UNPKG

@aristech-org/tts-client

Version:

A Node.js client library for the Aristech Text-to-Speech API

234 lines (233 loc) 9.01 kB
import * as grpc from '@grpc/grpc-js';
import { SpeechAudioFormat_Codec, SpeechAudioFormat_Container } from './generated/TTSTypes.js';
import { PhonesetRequest, SpeechRequest, SpeechServiceClient, TranscriptionRequest, VoiceListRequest } from './generated/TTSServices.js';
import fs from 'fs';
export * from './generated/TTSTypes.js';
export { PhonesetRequest, PhonesetResponse, SpeechRequest, SpeechResponse, SpeechServiceClient, TranscriptionRequest, TranscriptionResponse, VoiceListRequest } from './generated/TTSServices.js';

/**
 * Promise-based client for the Aristech Text-to-Speech gRPC service.
 *
 * A fresh `SpeechServiceClient` is built from the stored connection options
 * for every call (see `getClient`).
 */
export class TtsClient {
    cOptions;

    /**
     * @param options Connection options read by `getClient`: `host`, `ssl`,
     *   `rootCert` (path), `rootCertContent`, `auth` (`token`/`secret`) and
     *   `grpcClientOptions`. May be omitted, in which case the defaults of
     *   `getClient` apply.
     */
    constructor(options) {
        // Fall back to an empty object so that `getClient` can destructure
        // the options even when the client is constructed without arguments.
        this.cOptions = options ?? {};
    }

    /**
     * Lists all available voices.
     * @param request Optional partial voice list request
     * @returns A promise resolving to an array with every voice streamed by the server
     */
    listVoices(request) {
        return new Promise((res, rej) => {
            const client = this.getClient();
            const req = VoiceListRequest.create(request);
            const stream = client.getVoiceList(req);
            const voices = [];
            stream.on('data', (voice) => {
                voices.push(voice);
            });
            stream.on('end', () => {
                res(voices);
            });
            stream.on('error', (err) => {
                rej(err);
            });
        });
    }

    /**
     * Creates a stream of audio data from the given text.
     *
     * The request is always sent with `inputType: 'SSML'`,
     * `outputType: 'AUDIO'`, PCM codec and WAV container; these values
     * override whatever the caller passed for those fields.
     *
     * @param request The request object; `options.voiceId` is required
     * @returns A tuple containing the stream and the voice used for the audio generation
     * @throws (as a rejection) when `voiceId` is missing, when the voice is
     *   unknown to the server, or when listing the voices fails
     */
    async streamAudio(request) {
        const req = SpeechRequest.create({
            ...request,
            inputType: 'SSML',
            outputType: 'AUDIO',
            options: {
                ...request?.options,
                audio: {
                    ...request?.options?.audio,
                    codec: SpeechAudioFormat_Codec.PCM,
                    container: SpeechAudioFormat_Container.WAV,
                },
            },
        });
        // Get the voiceId from the request
        const voiceId = req.options?.voiceId;
        if (!voiceId) {
            throw new Error('voiceId is required');
        }
        // Get the voice audio specs. Because this is a real async method, a
        // failure here rejects the returned promise instead of being lost in
        // an async Promise executor.
        const voices = await this.listVoices();
        const voice = voices.find((v) => v.voiceId === voiceId);
        if (!voice) {
            throw new Error(`Voice with id "${voiceId}" not found`);
        }
        // The speech client is only created once validation has passed, so
        // no client is abandoned on the error paths above.
        const client = this.getClient();
        const stream = client.getSpeech(req);
        return [stream, voice];
    }

    /**
     * Creates an audio buffer from the given text.
     *
     * All chunks of the audio stream are collected and concatenated. When the
     * caller requested the WAV container (or none at all), a 44-byte WAV
     * header built from the voice's audio specs is prepended.
     *
     * @param request The request object
     * @returns The audio buffer
     * @throws (as a rejection) any error from `streamAudio`, any stream
     *   error, and any error raised while assembling the final buffer
     */
    async synthesize(request) {
        // Awaited outside of the Promise executor so that a rejection
        // (e.g. missing voiceId) propagates to the caller.
        const [stream, voice] = await this.streamAudio(request);
        return new Promise((res, rej) => {
            const rawChunks = [];
            stream.on('data', (msg) => {
                rawChunks.push(Buffer.from(msg.data));
            });
            stream.on('end', () => {
                // Guard the assembly step: an exception thrown inside an
                // event handler would otherwise be uncaught and leave the
                // promise pending.
                try {
                    const audioBuffer = Buffer.concat(rawChunks);
                    const requestedFormat = request?.options?.audio?.container || SpeechAudioFormat_Container.WAV;
                    // Wave headers need to be prepended to the audio data
                    // because while generating the audio data, the data length is not known
                    if (requestedFormat === SpeechAudioFormat_Container.WAV) {
                        const sampleRate = voice.audio.samplerate;
                        // NOTE(review): the voice's `bitrate` field is used as bits per
                        // sample and the channel count is fixed to 1 - confirm against the API.
                        const bitsPerSample = voice.audio.bitrate;
                        const header = createWaveHeader(audioBuffer.length, sampleRate, 1, bitsPerSample);
                        res(Buffer.concat([header, audioBuffer]));
                        return;
                    }
                    res(audioBuffer);
                } catch (err) {
                    rej(err);
                }
            });
            stream.on('error', (err) => {
                rej(err);
            });
        });
    }

    /**
     * This is an alias for the `synthesize` method.
     * @param request The request object
     * @returns The audio buffer
     */
    audioBuffer(request) {
        return this.synthesize(request);
    }

    /**
     * Retrieves the phoneset for the given voice.
     * @param request The request object
     * @returns The phoneset response
     */
    getPhoneset(request) {
        return new Promise((res, rej) => {
            const client = this.getClient();
            const req = PhonesetRequest.create(request);
            client.getPhoneset(req, (err, response) => {
                if (err) {
                    rej(err);
                    return;
                }
                res(response);
            });
        });
    }

    /**
     * Retrieves the transcription for the given request.
     * @param request The request object
     * @returns The transcription response
     */
    getTranscription(request) {
        return new Promise((res, rej) => {
            const client = this.getClient();
            const req = TranscriptionRequest.create(request);
            client.getTranscription(req, (err, response) => {
                if (err) {
                    rej(err);
                    return;
                }
                res(response);
            });
        });
    }

    /**
     * Builds a new `SpeechServiceClient` from the stored connection options.
     *
     * Host and transport security are resolved as follows:
     * - `host` defaults to `localhost:8423`.
     * - A root certificate is taken from `rootCertContent` or, failing that,
     *   read synchronously from the `rootCert` path.
     * - If the host carries a port and neither `ssl` nor a root certificate
     *   was given, TLS is assumed exactly when the port is 8424.
     * - If the host carries no port, 8424 is appended when `ssl` is `true`,
     *   otherwise 8423.
     *
     * @returns A new service client
     * @throws When the root certificate file cannot be read
     */
    getClient() {
        const { rootCert: rootCertPath, rootCertContent, auth, grpcClientOptions } = this.cOptions;
        let host = this.cOptions.host || 'localhost:8423';
        let ssl = this.cOptions.ssl === true;
        let rootCert = null;
        if (rootCertContent) {
            rootCert = Buffer.from(rootCertContent);
        } else if (rootCertPath) {
            rootCert = fs.readFileSync(rootCertPath);
        }
        const sslExplicit = typeof this.cOptions.ssl === 'boolean' || !!rootCert;
        const portRe = /[^:]+:([0-9]+)$/;
        const portMatch = host.match(portRe);
        if (portMatch) {
            // In case a port was provided but ssl was not specified
            // ssl is assumed when the port matches 8424
            if (!sslExplicit) {
                ssl = Number.parseInt(portMatch[1], 10) === 8424;
            }
        } else {
            // In case no port was provided, depending on the ssl settings
            // at the default non ssl port 8423 or ssl port 8424
            // NOTE(review): a root certificate without `ssl: true` selects
            // port 8423 here although TLS credentials are used below -
            // confirm that this is intended.
            host = sslExplicit && ssl ? `${host}:8424` : `${host}:8423`;
        }
        let creds = grpc.credentials.createInsecure();
        if (ssl || rootCert) {
            creds = grpc.credentials.createSsl(rootCert);
            // Token/secret metadata is only attached on TLS channels;
            // with insecure credentials `auth` is not used.
            if (auth) {
                const callCreds = grpc.credentials.createFromMetadataGenerator((_, cb) => {
                    const meta = new grpc.Metadata();
                    meta.add('token', auth.token);
                    meta.add('secret', auth.secret);
                    cb(null, meta);
                });
                creds = grpc.credentials.combineChannelCredentials(creds, callCreds);
            }
        }
        return new SpeechServiceClient(host, creds, grpcClientOptions);
    }
}

/**
 * A helper function to create a WAV header for the given audio data.
 *
 * Produces the 44-byte RIFF/WAVE header with a 16-byte PCM "fmt " sub-chunk
 * followed by the "data" sub-chunk descriptor. All multi-byte fields are
 * written little-endian.
 *
 * @param dataLength The length of the audio data in bytes
 * @param sampleRate Sample rate in Hz
 * @param numChannels Number of channels
 * @param bitsPerSample Bits per sample
 * @returns The WAV header as a buffer
 */
export function createWaveHeader(dataLength, sampleRate, numChannels, bitsPerSample) {
    const byteRate = (sampleRate * numChannels * bitsPerSample) / 8;
    const blockAlign = (numChannels * bitsPerSample) / 8;
    const headerSize = 44; // Standard WAV header size
    const buffer = new ArrayBuffer(headerSize);
    const view = new DataView(buffer);
    // RIFF chunk descriptor
    writeString(view, 0, "RIFF"); // ChunkID
    view.setUint32(4, 36 + dataLength, true); // ChunkSize = 36 + dataLength
    writeString(view, 8, "WAVE"); // Format
    // "fmt " sub-chunk
    writeString(view, 12, "fmt "); // Subchunk1ID
    view.setUint32(16, 16, true); // Subchunk1Size (PCM = 16)
    view.setUint16(20, 1, true); // AudioFormat (PCM = 1)
    view.setUint16(22, numChannels, true); // NumChannels (Mono = 1)
    view.setUint32(24, sampleRate, true); // SampleRate
    view.setUint32(28, byteRate, true); // ByteRate
    view.setUint16(32, blockAlign, true); // BlockAlign
    view.setUint16(34, bitsPerSample, true); // BitsPerSample
    // "data" sub-chunk
    writeString(view, 36, "data"); // Subchunk2ID
    view.setUint32(40, dataLength, true); // Subchunk2Size = dataLength
    return Buffer.from(buffer);
}

/**
 * Writes the UTF-16 code unit of each character of `str` as a single byte
 * into `view`, starting at `offset`.
 * @param view The DataView to write into
 * @param offset Byte offset of the first character
 * @param str The string to write (used here for 4-character ASCII chunk ids)
 */
function writeString(view, offset, str) {
    for (let i = 0; i < str.length; i++) {
        view.setUint8(offset + i, str.charCodeAt(i));
    }
}