@moonshine-ai/moonshine-js
Version:
On-device speech-to-text and voice control for web applications with Moonshine.
479 lines (435 loc) • 17.7 kB
text/typescript
import { Settings } from "./constants";
import MoonshineModel from "./model";
import MoonshineError from "./error";
import { AudioNodeVAD } from "@ricky0123/vad-web";
import Log from "./log";
/**
 * Callbacks are invoked at different phases of the lifecycle as audio is transcribed. You can control the behavior of the application
 * in response to model loading, starting of transcription, stopping of transcription, and updates to the transcription of the audio stream.
 *
 * @property onPermissionsRequested() - called when permissions to a user resource (e.g., microphone) have been requested (but not necessarily granted yet)
 *
 * @property onError(error: {@link MoonshineError}) - called when an error occurs.
 *
 * @property onModelLoadStarted() - called when the {@link MoonshineModel} and VAD begin to load (or download, if hosted elsewhere)
 *
 * @property onModelLoaded() - called when the {@link MoonshineModel} and VAD are loaded. This means the Transcriber is now ready to use.
 *
 * @property onTranscribeStarted() - called once when transcription starts
 *
 * @property onTranscribeStopped() - called once when transcription stops
 *
 * @property onTranscriptionUpdated(text: string) - when `useVAD === false` (i.e., streaming mode), this callback is invoked with the
 * speculative transcription of the audio buffered so far, once every {@link Settings.STREAM_UPDATE_INTERVAL} buffered frames.
 *
 * @property onTranscriptionCommitted(text: string, buffer?: AudioBuffer) - called every time a non-empty transcript is "committed"; when `useVAD === false` (streaming mode),
 * the transcript is committed between brief pauses in speech. When `useVAD === true`, the transcript is committed after speech events, or during brief pauses in long speech events.
 * `buffer` holds the audio that produced the transcript.
 *
 * @property onFrame(probability, frame, ema) - called for every frame of audio processed by the VAD, with the VAD probabilities for the frame,
 * the frame itself, and the current exponential moving average of the speech probability
 *
 * @property onSpeechStart() - called when the VAD model detects the start of speech
 *
 * @property onSpeechEnd() - called when the VAD model detects the end of speech
 *
 * @interface
 */
interface TranscriberCallbacks {
    onPermissionsRequested: () => any;
    // `error` is deliberately left as `any`: narrowing it would reject handlers that
    // callers have already written against the previously untyped parameter.
    onError: (error: any) => any;
    onModelLoadStarted: () => any;
    onModelLoaded: () => any;
    onTranscribeStarted: () => any;
    onTranscribeStopped: () => any;
    onTranscriptionUpdated: (text: string) => any;
    onTranscriptionCommitted: (text: string, buffer?: AudioBuffer) => any;
    // `probs` and `frame` are forwarded untouched from the VAD, so their exact shape is
    // owned by the VAD library; `ema` is always the numeric moving average.
    onFrame: (probs: any, frame: any, ema: number) => any;
    onSpeechStart: () => any;
    onSpeechEnd: () => any;
}
/**
 * Fallback implementation of every {@link TranscriberCallbacks} hook.
 *
 * Each hook only writes a trace line through {@link Log}: errors go to `Log.error`,
 * everything else to `Log.log`. The frame hook and the buffer argument of the commit
 * hook are accepted but not included in the trace.
 */
const defaultTranscriberCallbacks: TranscriberCallbacks = {
    // --- permissions and errors ---
    onPermissionsRequested() {
        Log.log("Transcriber.onPermissionsRequested()");
    },
    onError(error) {
        Log.error("Transcriber.onError(" + error + ")");
    },

    // --- model lifecycle ---
    onModelLoadStarted() {
        Log.log("Transcriber.onModelLoadStarted()");
    },
    onModelLoaded() {
        Log.log("Transcriber.onModelLoaded()");
    },

    // --- transcription lifecycle ---
    onTranscribeStarted() {
        Log.log("Transcriber.onTranscribeStarted()");
    },
    onTranscribeStopped() {
        Log.log("Transcriber.onTranscribeStopped()");
    },
    onTranscriptionUpdated(text: string) {
        Log.log("Transcriber.onTranscriptionUpdated(" + text + ")");
    },
    onTranscriptionCommitted(text: string, buffer?: AudioBuffer) {
        Log.log("Transcriber.onTranscriptionCommitted(" + text + ")");
    },

    // --- voice activity ---
    onFrame(probs, frame, ema) {
        Log.log("Transcriber.onFrame()");
    },
    onSpeechStart() {
        Log.log("Transcriber.onSpeechStart()");
    },
    onSpeechEnd() {
        Log.log("Transcriber.onSpeechEnd()");
    },
};
/**
 * Fixed-capacity accumulator for the audio frames of the speech segment currently being transcribed.
 *
 * Frames of `Settings.FRAME_SIZE` samples are appended one after another. Alongside the audio, the
 * buffer tracks an exponential moving average (EMA) of the per-frame speech probability, which is
 * used to decide when a pause in speech is a good moment to commit the transcript.
 *
 * Capacity is `Settings.VAD_COMMIT_INTERVAL` frames when `useVAD` is true and
 * `Settings.STREAM_COMMIT_MAX_INTERVAL` frames otherwise.
 */
class SpeechBuffer {
    // The four fields below are assigned by flush(), which the constructor always calls.
    private buffer!: Float32Array;
    private frameCount!: number;
    /** Most recent value of the speech-probability EMA; reset to 0 by {@link flush}. */
    public frameEMA!: number;
    private speechEMA!: (value: number) => number;
    private useVAD: boolean;

    /**
     * @param useVAD selects which capacity setting applies (see class description).
     */
    constructor(useVAD: boolean) {
        this.useVAD = useVAD;
        this.flush();
    }

    /**
     * Discards all buffered frames and restarts the EMA.
     *
     * A new backing array is allocated rather than reusing the old one, so views previously
     * handed out by {@link subarray} keep their contents after the flush.
     */
    public flush(): void {
        this.buffer = new Float32Array(
            this.maxCommitInterval() * Settings.FRAME_SIZE
        );
        this.speechEMA = this.ema(Settings.STREAM_COMMIT_EMA_PERIOD);
        this.frameEMA = 0.0;
        this.frameCount = 0;
    }

    /**
     * Appends one frame to the buffer.
     *
     * Callers are expected to check {@link shouldSet} first: writing to a full buffer makes the
     * underlying typed-array copy throw a `RangeError`.
     *
     * @param frame samples of the frame; expected to hold `Settings.FRAME_SIZE` values.
     * @param p optional VAD probabilities for the frame; when given, the EMA is updated with them.
     */
    public set(frame: ArrayLike<number>, p?: { isSpeech: number }): void {
        this.buffer.set(frame, this.frameCount * Settings.FRAME_SIZE);
        if (p) this.updateEMA(p);
        this.frameCount += 1;
    }

    /**
     * Feeds the speech probability of one frame into the EMA and stores the result in {@link frameEMA}.
     *
     * @param p VAD probabilities; only `isSpeech` is read.
     */
    public updateEMA(p: { isSpeech: number }): void {
        this.frameEMA = this.speechEMA(p.isSpeech);
    }

    /**
     * @returns a view (no copy) over the samples buffered so far.
     */
    public subarray(): Float32Array {
        return this.buffer.subarray(0, this.frameCount * Settings.FRAME_SIZE);
    }

    /**
     * @returns an independent copy of the samples buffered so far.
     */
    public copy(): Float32Array {
        return this.buffer.slice(0, this.frameCount * Settings.FRAME_SIZE);
    }

    /**
     * @returns true when at least one frame has been buffered since the last flush.
     */
    public hasFrames(): boolean {
        return this.frameCount > 0;
    }

    /**
     * @returns true while there is room for another frame.
     */
    public shouldSet(): boolean {
        // Strictly less than: once frameCount equals the capacity the buffer is full and a
        // further set() would write past its end.
        return this.frameCount < this.maxCommitInterval();
    }

    /**
     * @returns true when the buffer is not full and the number of buffered frames is a multiple of
     * `Settings.STREAM_UPDATE_INTERVAL`, i.e. it is time for a speculative transcription update.
     */
    public shouldUpdate(): boolean {
        return (
            this.frameCount < this.maxCommitInterval() &&
            this.frameCount % Settings.STREAM_UPDATE_INTERVAL === 0
        );
    }

    /**
     * Decides whether the buffered audio should be committed now.
     *
     * A commit is due when the buffer is full ("forced commit"), or when the speech EMA has dropped
     * to `Settings.STREAM_COMMIT_EMA_THRESHOLD` or below after at least
     * `Settings.STREAM_COMMIT_MIN_INTERVAL` frames ("speech pause"). The reason is logged.
     *
     * @returns true when a commit is due.
     */
    public shouldCommit(): boolean {
        const isFull = this.frameCount === this.maxCommitInterval();
        // The log message and the decision share this one condition so they cannot disagree.
        const isPause =
            this.frameEMA <= Settings.STREAM_COMMIT_EMA_THRESHOLD &&
            this.frameCount >= this.minCommitInterval();

        if (isFull) {
            Log.log(`Forced commit, frameCount: ${this.frameCount}`);
        } else if (isPause) {
            Log.log(`Speech pause, frameCount: ${this.frameCount}`);
        }
        return isFull || isPause;
    }

    /**
     * Builds a stateful EMA updater with smoothing factor `2 / (period + 1)`.
     *
     * @param period EMA period.
     * @returns a function that takes the next sample and returns the updated average; the first
     * sample initializes the average.
     */
    private ema(period: number): (value: number) => number {
        const k = 2 / (period + 1);
        let emaPrev: number | null = null;
        return function update(value: number): number {
            if (emaPrev === null) {
                emaPrev = value; // initialize with first value
            } else {
                emaPrev = value * k + emaPrev * (1 - k);
            }
            return emaPrev;
        };
    }

    /** Minimum number of buffered frames before a pause may trigger a commit. */
    private minCommitInterval(): number {
        return Settings.STREAM_COMMIT_MIN_INTERVAL;
    }

    /** Buffer capacity in frames; reaching it forces a commit. */
    private maxCommitInterval(): number {
        return this.useVAD
            ? Settings.VAD_COMMIT_INTERVAL
            : Settings.STREAM_COMMIT_MAX_INTERVAL;
    }
}
/**
 * Implements real-time transcription of an audio stream sourced from a WebAudio-compliant MediaStream object.
 *
 * Read more about working with MediaStreams: {@link https://developer.mozilla.org/en-US/docs/Web/API/MediaStream}
 */
class Transcriber {
    // Shared by every Transcriber so the same weights are only instantiated once.
    // Keyed by model URL *and* precision (see the constructor).
    private static models: Map<string, MoonshineModel> = new Map();
    private sttModel: MoonshineModel;
    // Created by load(); undefined until then.
    private vadModel?: AudioNodeVAD;
    callbacks: TranscriberCallbacks;
    private useVAD: boolean;
    // Stream handed to attachStream() before the VAD existed; attached at the end of load().
    private mediaStream?: MediaStream;
    private speechBuffer?: SpeechBuffer;
    protected audioContext: AudioContext;
    public isActive: boolean = false;

    /**
     * Creates a transcriber for transcribing a MediaStream from any source. After creating the {@link Transcriber}, you must invoke
     * {@link Transcriber.attachStream} to provide a MediaStream that you want to transcribe.
     *
     * @param modelURL The URL that the underlying {@link MoonshineModel} weights should be loaded from,
     * relative to {@link Settings.BASE_ASSET_PATH.MOONSHINE}.
     *
     * @param callbacks A set of {@link TranscriberCallbacks} used to trigger behavior at different steps of the
     * transcription lifecycle. Hooks that are not supplied fall back to logging defaults. For transcription-only use cases,
     * you should define the {@link TranscriberCallbacks} yourself; when using the transcriber for voice control, you should
     * create a {@link VoiceController} and pass it in.
     *
     * @param useVAD A boolean specifying whether or not to use Voice Activity Detection (VAD) for deciding when to perform transcriptions.
     * When set to `true`, the transcriber will only process speech at the end of each chunk of voice activity; when set to `false`, the transcriber will
     * operate in streaming mode, generating continuous transcriptions on a rapid interval.
     *
     * @param precision Forwarded unchanged to the {@link MoonshineModel} constructor; defaults to `"quantized"`.
     *
     * @example
     * This basic example demonstrates the use of the transcriber with custom callbacks:
     *
     * ``` ts
     * import Transcriber from "@moonshine-ai/moonshine-js";
     *
     * var transcriber = new Transcriber(
     *      "model/tiny",
     *      {
     *          onModelLoadStarted() {
     *              console.log("onModelLoadStarted()");
     *          },
     *          onTranscribeStarted() {
     *              console.log("onTranscribeStarted()");
     *          },
     *          onTranscribeStopped() {
     *              console.log("onTranscribeStopped()");
     *          },
     *          onTranscriptionUpdated(text: string | undefined) {
     *              console.log(
     *                  "onTranscriptionUpdated(" + text + ")"
     *              );
     *          },
     *          onTranscriptionCommitted(text: string | undefined) {
     *              console.log(
     *                  "onTranscriptionCommitted(" + text + ")"
     *              );
     *          },
     *      },
     *      false // use streaming mode
     * );
     *
     * // Get a MediaStream from somewhere (user mic, active tab, an <audio> element, WebRTC source, etc.)
     * ...
     *
     * transcriber.attachStream(stream);
     * transcriber.start();
     * ```
     */
    public constructor(
        modelURL: string,
        callbacks: Partial<TranscriberCallbacks> = {},
        useVAD: boolean = true,
        precision: string = "quantized"
    ) {
        this.callbacks = { ...defaultTranscriberCallbacks, ...callbacks };
        // we want to avoid re-downloading the same model weights if we can avoid it
        // so we only create a new model of the requested type if it hasn't been already.
        // The precision is part of the key: the same URL at a different precision is a
        // different model and must not be served from the cache.
        const cacheKey = `${modelURL}::${precision}`;
        let model = Transcriber.models.get(cacheKey);
        if (model === undefined) {
            model = new MoonshineModel(modelURL, precision);
            Transcriber.models.set(cacheKey, model);
        }
        this.sttModel = model;
        this.useVAD = useVAD;
        this.audioContext = new AudioContext();
    }

    /**
     * Preloads the models and initializes the buffer required for transcription.
     *
     * Invokes `onModelLoadStarted` first and `onModelLoaded` last. If the speech-to-text model fails
     * to load, `onError` is invoked with `MoonshineError.PlatformUnsupported` and the original error
     * is rethrown. A stream passed to {@link Transcriber.attachStream} before loading is attached
     * once the VAD exists.
     */
    public async load(): Promise<void> {
        this.callbacks.onModelLoadStarted();
        try {
            await this.sttModel.loadModel();
        } catch (err) {
            this.callbacks.onError(MoonshineError.PlatformUnsupported);
            throw err;
        }

        // behavior
        // useVAD: commit when the buffer fills up, on a detected pause, or onSpeechEnd
        // !useVAD: update every updateInterval frames; commit on detected pause (w/ EMA below threshold) occurring between min and max interval OR on max.
        const speechBuffer = new SpeechBuffer(this.useVAD);
        this.speechBuffer = speechBuffer;
        let isTalking = false;

        this.vadModel = await AudioNodeVAD.new(this.audioContext, {
            onFrameProcessed: (p, frame) => {
                speechBuffer.updateEMA(p);
                this.callbacks.onFrame(p, frame, speechBuffer.frameEMA);
                if (!isTalking) {
                    return;
                }
                if (speechBuffer.shouldSet()) {
                    speechBuffer.set(frame);
                }
                // Evaluate once per frame: nothing below changes the buffer until the final
                // flush, and shouldCommit() logs every time it is called.
                const commitNow = speechBuffer.shouldCommit();
                if (speechBuffer.hasFrames()) {
                    if (commitNow) {
                        // copy the buffer so that it doesn't get cleared before the inference happens
                        this.commitTranscription(speechBuffer.copy());
                    } else if (!this.useVAD && speechBuffer.shouldUpdate()) {
                        this.updateTranscription(speechBuffer.subarray());
                    }
                }
                if (commitNow) {
                    // clear buffer (leave some overhang?)
                    speechBuffer.flush();
                }
            },
            onVADMisfire: () => {
                Log.log("Transcriber.onVADMisfire()");
            },
            onSpeechStart: () => {
                Log.log("Transcriber.onSpeechStart()");
                this.callbacks.onSpeechStart();
                isTalking = true;
            },
            onSpeechEnd: () => {
                Log.log("Transcriber.onSpeechEnd()");
                this.callbacks.onSpeechEnd();
                // The buffer can be empty here when a pause commit flushed it just before the
                // VAD reported the end of speech; there is nothing to transcribe in that case.
                if (speechBuffer.hasFrames()) {
                    this.commitTranscription(speechBuffer.copy());
                }
                speechBuffer.flush();
                isTalking = false;
            },
            model: "v5",
            baseAssetPath: Settings.BASE_ASSET_PATH.SILERO_VAD,
            onnxWASMBasePath: Settings.BASE_ASSET_PATH.ONNX_RUNTIME,
        });

        if (this.mediaStream) {
            this.attachStream(this.mediaStream);
        }
        this.callbacks.onModelLoaded();
    }

    /**
     * Transcribes `samples` and reports the result through `onTranscriptionUpdated`.
     * Inference failures are logged, not thrown.
     */
    private updateTranscription(samples: Float32Array): void {
        this.sttModel
            .generate(samples)
            .then((text) => {
                this.callbacks.onTranscriptionUpdated(text);
            })
            .catch((err) => {
                Log.error("Generation misfire: " + err);
            });
    }

    /**
     * Transcribes `samples` and, when the transcript is non-empty, reports it together with the
     * audio through `onTranscriptionCommitted`. Inference failures are logged, not thrown.
     *
     * `samples` must be a copy that will not be modified while inference is running.
     */
    private commitTranscription(samples: Float32Array): void {
        this.sttModel
            .generate(samples)
            .then((text) => {
                if (text) {
                    this.callbacks.onTranscriptionCommitted(
                        text,
                        this.getAudioBuffer(samples)
                    );
                }
            })
            .catch((err) => {
                Log.error("Generation misfire: " + err);
            });
    }

    /**
     * Attaches a MediaStream to this {@link Transcriber} for transcription. A MediaStream must be attached before
     * starting transcription.
     *
     * If the VAD has not been loaded yet, the stream is remembered and attached at the end of
     * {@link Transcriber.load}. A falsy `stream` is ignored.
     *
     * @param stream A MediaStream to transcribe
     */
    public attachStream(stream: MediaStream): void {
        if (!stream) {
            return;
        }
        if (this.vadModel) {
            const sourceNode = new MediaStreamAudioSourceNode(
                this.audioContext,
                {
                    mediaStream: stream,
                }
            );
            this.vadModel.receive(sourceNode);
            Log.log(
                "Transcriber.attachStream(): VAD set to receive source node from stream."
            );
        } else {
            // save stream to attach later, after loading
            this.mediaStream = stream;
        }
    }

    /**
     * Detaches the MediaStream used for transcription.
     *
     * Not implemented yet: calling this currently has no effect.
     * TODO
     */
    public detachStream() {
        // TODO
    }

    /**
     * Wraps raw samples in a single-channel AudioBuffer with a sample rate of 16000 Hz. This is useful in cases where
     * we want to double-check the audio being input to the model while debugging.
     *
     * @param buffer The samples to wrap; must not be empty, because an AudioBuffer cannot have zero length.
     *
     * @returns An AudioBuffer containing a copy of `buffer`
     */
    public getAudioBuffer(buffer: Float32Array): AudioBuffer {
        const numChannels = 1;
        const audioBuffer = this.audioContext.createBuffer(
            numChannels,
            buffer.length,
            16000
        );
        audioBuffer.getChannelData(0).set(buffer);
        return audioBuffer;
    }

    /**
     * Starts transcription.
     *
     * Transcription will stop when {@link stop} is called. Calling `start()` while already active does nothing.
     *
     * The models are loaded first if they have not been loaded yet. If loading fails, the transcriber is left
     * inactive (so `start()` can be retried) and the error is rethrown. If {@link stop} is called while loading
     * is still in progress, transcription is not started.
     *
     * Note that the {@link Transcriber} must have a MediaStream attached via {@link Transcriber.attachStream} before
     * starting transcription.
     */
    public async start(): Promise<void> {
        if (this.isActive) {
            return;
        }
        this.isActive = true;

        // load model if not loaded
        if (
            (!this.sttModel.isLoaded() && !this.sttModel.isLoading()) ||
            this.vadModel === undefined
        ) {
            try {
                await this.load();
            } catch (err) {
                // Without this reset the instance would stay "active" forever and every
                // later start() would return immediately.
                this.isActive = false;
                throw err;
            }
        }

        if (!this.isActive) {
            // stop() was called while we were waiting for the models.
            return;
        }

        this.callbacks.onTranscribeStarted();
        this.vadModel?.start();
        this.audioContext.resume().catch((err) => {
            Log.warn("Transcriber.start(): AudioContext.resume() failed: " + err);
        });
        setTimeout(() => {
            if (this.audioContext.state === "suspended") {
                Log.warn(
                    "AudioContext is suspended, this usually happens on Chrome when you start trying to access an audio source (like a microphone or video) before the user has interacted with the page. Chrome blocks access until there has been a user gesture, so you'll need to rework your code to call start() after an interaction."
                );
            }
        }, 1000);
    }

    /**
     * Stops transcription.
     *
     * Marks the transcriber inactive, invokes `onTranscribeStopped`, and pauses the VAD if it has been loaded.
     */
    public stop() {
        this.isActive = false;
        this.callbacks.onTranscribeStopped();
        this.vadModel?.pause();
    }
}
// Module exports: the Transcriber class and the TranscriberCallbacks interface it accepts.
export { Transcriber, TranscriberCallbacks };