// whisper.rn: React Native binding of whisper.cpp
import {
NativeEventEmitter,
DeviceEventEmitter,
Platform,
DeviceEventEmitterStatic,
Image,
} from 'react-native'
import RNWhisper, { NativeWhisperContext } from './NativeRNWhisper'
import type {
TranscribeOptions,
TranscribeResult,
CoreMLAsset,
} from './NativeRNWhisper'
import AudioSessionIos from './AudioSessionIos'
import type {
AudioSessionCategoryIos,
AudioSessionCategoryOptionIos,
AudioSessionModeIos,
} from './AudioSessionIos'
import { version } from './version.json'
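// iOS delivers native module events via NativeEventEmitter, while Android
// uses the global DeviceEventEmitter.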
let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
if (Platform.OS === 'ios') {
// @ts-ignore
EventEmitter = new NativeEventEmitter(RNWhisper)
}
if (Platform.OS === 'android') {
EventEmitter = DeviceEventEmitter
}
export type {
TranscribeOptions,
TranscribeResult,
AudioSessionCategoryIos,
AudioSessionCategoryOptionIos,
AudioSessionModeIos,
}
const EVENT_ON_TRANSCRIBE_PROGRESS = '@RNWhisper_onTranscribeProgress'
const EVENT_ON_TRANSCRIBE_NEW_SEGMENTS = '@RNWhisper_onTranscribeNewSegments'
const EVENT_ON_REALTIME_TRANSCRIBE = '@RNWhisper_onRealtimeTranscribe'
const EVENT_ON_REALTIME_TRANSCRIBE_END = '@RNWhisper_onRealtimeTranscribeEnd'
export type TranscribeNewSegmentsResult = {
nNew: number
totalNNew: number
result: string
segments: TranscribeResult['segments']
}
export type TranscribeNewSegmentsNativeEvent = {
contextId: number
jobId: number
result: TranscribeNewSegmentsResult
}
// Callback options become booleans in TranscribeFileNativeOptions; the native
// side only needs to know whether to emit the corresponding events
export type TranscribeFileOptions = TranscribeOptions & {
/**
* Progress callback; progress ranges from 0 to 100
*/
onProgress?: (progress: number) => void
/**
* Callback when new segments are transcribed
*/
onNewSegments?: (result: TranscribeNewSegmentsResult) => void
}
export type TranscribeProgressNativeEvent = {
contextId: number
jobId: number
progress: number
}
export type AudioSessionSettingIos = {
category: AudioSessionCategoryIos
options?: AudioSessionCategoryOptionIos[]
mode?: AudioSessionModeIos
active?: boolean
}
// Codegen lacks TSIntersectionType support, so we don't include this type in the native spec
export type TranscribeRealtimeOptions = TranscribeOptions & {
/**
* Max duration of realtime recording, in seconds.
* whisper.cpp processes audio in fixed 30-second chunks (a hard constraint),
* so a value <= 30 seconds is recommended. (Default: 30)
*/
realtimeAudioSec?: number
/**
* Optimize transcription performance by slicing the audio samples when `realtimeAudioSec` > 30.
* Set `realtimeAudioSliceSec` < 30 so each slice fits within whisper.cpp's
* hard constraint of 30-second audio chunks. (Default: equal to `realtimeAudioSec`)
*/
realtimeAudioSliceSec?: number
/**
* Min duration of audio, in seconds, required to start transcribing each slice.
* The minimum value is 0.5 s and the maximum is `realtimeAudioSliceSec`. (Default: 1)
*/
realtimeAudioMinSec?: number
/**
* Output path for the recorded audio file. If not set, the audio is not saved.
* (Default: undefined)
*/
audioOutputPath?: string
/**
* Start transcribing while recording only when the audio volume exceeds the threshold, using VAD (Voice Activity Detection).
* The first VAD check is triggered after 2 seconds of recording.
* (Default: false)
*/
useVad?: boolean
/**
* Length of collected audio used for VAD, in ms; cannot be less than 2000 ms. (Default: 2000)
*/
vadMs?: number
/**
* VAD threshold. (Default: 0.6)
*/
vadThold?: number
/**
* Cutoff frequency of the high-pass filter applied in VAD, in Hz. (Default: 100.0)
*/
vadFreqThold?: number
/**
* iOS: Audio session settings applied when transcription starts.
* Leave unset to keep the current audio session state.
*/
audioSessionOnStartIos?: AudioSessionSettingIos
/**
* iOS: Audio session settings applied when transcription stops.
* - Leave unset to keep the last audio session state.
* - Use `'restore'` to restore the audio session state captured before transcription started.
*/
audioSessionOnStopIos?: string | AudioSessionSettingIos
}
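// Options sketch (illustrative values; `language` comes from TranscribeOptions):
// a 60-second realtime session split into 25-second slices with VAD enabled,
// setting the iOS audio session on start and restoring it on stop.
//
//   const realtimeOptions: TranscribeRealtimeOptions = {
//     language: 'en',
//     realtimeAudioSec: 60,
//     realtimeAudioSliceSec: 25,
//     useVad: true,
//     vadThold: 0.6,
//     audioSessionOnStartIos: {
//       category: AudioSessionIos.Category.PlayAndRecord,
//       options: [AudioSessionIos.CategoryOption.MixWithOthers],
//       mode: AudioSessionIos.Mode.Default,
//     },
//     audioSessionOnStopIos: 'restore',
//   }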
export type TranscribeRealtimeEvent = {
contextId: number
jobId: number
/** Whether audio is still being captured; when false, the event carries the final result */
isCapturing: boolean
isStoppedByAction?: boolean
code: number
data?: TranscribeResult
error?: string
processTime: number
recordingTime: number
slices?: Array<{
code: number
error?: string
data?: TranscribeResult
processTime: number
recordingTime: number
}>
}
export type TranscribeRealtimeNativePayload = {
/** Whether audio is still being captured; when false, the event carries the final result */
isCapturing: boolean
isStoppedByAction?: boolean
code: number
processTime: number
recordingTime: number
isUseSlices: boolean
sliceIndex: number
data?: TranscribeResult
error?: string
}
export type TranscribeRealtimeNativeEvent = {
contextId: number
jobId: number
payload: TranscribeRealtimeNativePayload
}
export type BenchResult = {
config: string
nThreads: number
encodeMs: number
decodeMs: number
batchMs: number
promptMs: number
}
const updateAudioSession = async (setting: AudioSessionSettingIos) => {
await AudioSessionIos.setCategory(
setting.category,
setting.options || [],
)
if (setting.mode) {
await AudioSessionIos.setMode(setting.mode)
}
await AudioSessionIos.setActive(setting.active ?? true)
}
export class WhisperContext {
id: number
gpu: boolean = false
reasonNoGPU: string = ''
constructor({
contextId,
gpu,
reasonNoGPU,
}: NativeWhisperContext) {
this.id = contextId
this.gpu = gpu
this.reasonNoGPU = reasonNoGPU
}
private transcribeWithNativeMethod(method: 'transcribeFile' | 'transcribeData', data: string, options: TranscribeFileOptions = {}): {
stop: () => Promise<void>
promise: Promise<TranscribeResult>
} {
const jobId: number = Math.floor(Math.random() * 10000)
const { onProgress, onNewSegments, ...rest } = options
let progressListener: any
let lastProgress: number = 0
if (onProgress) {
progressListener = EventEmitter.addListener(
EVENT_ON_TRANSCRIBE_PROGRESS,
(evt: TranscribeProgressNativeEvent) => {
const { contextId, progress } = evt
if (contextId !== this.id || evt.jobId !== jobId) return
lastProgress = progress > 100 ? 100 : progress
onProgress(lastProgress)
},
)
}
const removeProgressListener = () => {
if (progressListener) {
progressListener.remove()
progressListener = null
}
}
let newSegmentsListener: any
if (onNewSegments) {
newSegmentsListener = EventEmitter.addListener(
EVENT_ON_TRANSCRIBE_NEW_SEGMENTS,
(evt: TranscribeNewSegmentsNativeEvent) => {
const { contextId, result } = evt
if (contextId !== this.id || evt.jobId !== jobId) return
onNewSegments(result)
},
)
}
const removeNewSegmentsListener = () => {
if (newSegmentsListener) {
newSegmentsListener.remove()
newSegmentsListener = null
}
}
return {
stop: async () => {
await RNWhisper.abortTranscribe(this.id, jobId)
removeProgressListener()
removeNewSegmentsListener()
},
promise: RNWhisper[method](this.id, jobId, data, {
...rest,
onProgress: !!onProgress,
onNewSegments: !!onNewSegments,
})
.then((result) => {
removeProgressListener()
removeNewSegmentsListener()
if (!result.isAborted && lastProgress !== 100) {
// Emit a final 100% in case the last progress event was never delivered
onProgress?.(100)
}
return result
})
.catch((e) => {
removeProgressListener()
removeNewSegmentsListener()
throw e
}),
}
}
/**
* Transcribe an audio file (file path or base64-encoded wav)
* base64 input must include the `data:audio/wav;base64,` prefix
*/
transcribe(
filePathOrBase64: string | number,
options: TranscribeFileOptions = {},
): {
/** Stop the transcription */
stop: () => Promise<void>
/** Promise resolving to the transcription result */
promise: Promise<TranscribeResult>
} {
let path = ''
if (typeof filePathOrBase64 === 'number') {
try {
const source = Image.resolveAssetSource(filePathOrBase64)
if (source) path = source.uri
} catch (e) {
throw new Error(`Invalid asset: ${filePathOrBase64}`)
}
} else {
if (filePathOrBase64.startsWith('http'))
throw new Error(
'Transcribing remote files is not supported, please download the file first',
)
path = filePathOrBase64
}
if (path.startsWith('file://')) path = path.slice(7)
return this.transcribeWithNativeMethod('transcribeFile', path, options)
}
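// Usage sketch (illustrative): transcribing a bundled audio file with
// progress and new-segment callbacks. `ctx` is assumed to be an existing
// WhisperContext, and the asset path is hypothetical.
//
//   const { stop, promise } = ctx.transcribe(require('../assets/sample.wav'), {
//     language: 'en',
//     onProgress: (progress) => console.log(`Progress: ${progress}%`),
//     onNewSegments: ({ nNew, totalNNew, result }) =>
//       console.log(`+${nNew} segments (${totalNNew} total): ${result}`),
//   })
//   const { result } = await promise
//   console.log('Final result:', result)
//   // call `await stop()` to abort the job early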
/**
* Transcribe audio data (base64 encoded float32 PCM data)
*/
transcribeData(data: string, options: TranscribeFileOptions = {}): {
stop: () => Promise<void>
promise: Promise<TranscribeResult>
} {
return this.transcribeWithNativeMethod('transcribeData', data, options)
}
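// Sketch: building the base64 payload for transcribeData() from raw float32
// PCM samples (whisper.cpp expects 16 kHz mono). Assumes the `buffer`
// polyfill package is installed, since React Native has no global Buffer.
//
//   import { Buffer } from 'buffer'
//
//   const samples = new Float32Array(16000) // 1 second of silence at 16 kHz
//   const base64Pcm = Buffer.from(samples.buffer).toString('base64')
//   const { promise } = ctx.transcribeData(base64Pcm)
//   const { result } = await promise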
/** Transcribe the microphone audio stream; microphone permission is required */
async transcribeRealtime(options: TranscribeRealtimeOptions = {}): Promise<{
/** Stop the realtime transcription */
stop: () => Promise<void>
/** Subscribe to realtime transcribe events */
subscribe: (callback: (event: TranscribeRealtimeEvent) => void) => void
}> {
let lastTranscribePayload: TranscribeRealtimeNativePayload
const slices: TranscribeRealtimeNativePayload[] = []
let sliceIndex: number = 0
let tOffset: number = 0
const putSlice = (payload: TranscribeRealtimeNativePayload) => {
if (!payload.isUseSlices || !payload.data) return
if (sliceIndex !== payload.sliceIndex) {
const { segments = [] } = slices[sliceIndex]?.data || {}
tOffset = segments[segments.length - 1]?.t1 || 0
}
;({ sliceIndex } = payload)
slices[sliceIndex] = {
...payload,
data: {
...payload.data,
segments: payload.data.segments.map((segment) => ({
...segment,
t0: segment.t0 + tOffset,
t1: segment.t1 + tOffset,
})),
}
}
}
const mergeSlicesIfNeeded = (
payload: TranscribeRealtimeNativePayload,
): TranscribeRealtimeNativePayload => {
if (!payload.isUseSlices) return payload
const mergedPayload: any = {}
slices.forEach((slice) => {
mergedPayload.data = {
result:
(mergedPayload.data?.result || '') + (slice.data?.result || ''),
segments: [
...(mergedPayload?.data?.segments || []),
...(slice.data?.segments || []),
],
}
mergedPayload.processTime = slice.processTime
mergedPayload.recordingTime =
(mergedPayload?.recordingTime || 0) + slice.recordingTime
})
return { ...payload, ...mergedPayload, slices }
}
let prevAudioSession: AudioSessionSettingIos | undefined
if (Platform.OS === 'ios' && options?.audioSessionOnStartIos) {
// iOS: Remember current audio session state
if (options?.audioSessionOnStopIos === 'restore') {
const categoryResult = await AudioSessionIos.getCurrentCategory()
const mode = await AudioSessionIos.getCurrentMode()
prevAudioSession = {
...categoryResult,
mode,
active: false, // TODO: Need to check isOtherAudioPlaying to set active
}
}
// iOS: Update audio session state
await updateAudioSession(options?.audioSessionOnStartIos)
}
if (Platform.OS === 'ios' && typeof options?.audioSessionOnStopIos === 'object') {
prevAudioSession = options?.audioSessionOnStopIos
}
const jobId: number = Math.floor(Math.random() * 10000)
try {
await RNWhisper.startRealtimeTranscribe(this.id, jobId, options)
} catch (e) {
if (prevAudioSession) await updateAudioSession(prevAudioSession)
throw e
}
return {
stop: async () => {
await RNWhisper.abortTranscribe(this.id, jobId)
if (prevAudioSession) await updateAudioSession(prevAudioSession)
},
subscribe: (callback: (event: TranscribeRealtimeEvent) => void) => {
let transcribeListener: any = EventEmitter.addListener(
EVENT_ON_REALTIME_TRANSCRIBE,
(evt: TranscribeRealtimeNativeEvent) => {
const { contextId, payload } = evt
if (contextId !== this.id || evt.jobId !== jobId) return
lastTranscribePayload = payload
putSlice(payload)
callback({
contextId,
jobId: evt.jobId,
...mergeSlicesIfNeeded(payload),
})
},
)
let endListener: any = EventEmitter.addListener(
EVENT_ON_REALTIME_TRANSCRIBE_END,
(evt: TranscribeRealtimeNativeEvent) => {
const { contextId, payload } = evt
if (contextId !== this.id || evt.jobId !== jobId) return
const lastPayload = {
...lastTranscribePayload,
...payload,
}
putSlice(lastPayload)
callback({
contextId,
jobId: evt.jobId,
...mergeSlicesIfNeeded(lastPayload),
isCapturing: false,
})
if (transcribeListener) {
transcribeListener.remove()
transcribeListener = null
}
if (endListener) {
endListener.remove()
endListener = null
}
},
)
},
}
}
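// Usage sketch (illustrative): start a realtime session, log partial and
// final results, and stop from a UI action. `ctx` is an existing
// WhisperContext; times are reported in milliseconds.
//
//   const { stop, subscribe } = await ctx.transcribeRealtime({
//     language: 'en',
//     realtimeAudioSec: 60,
//     realtimeAudioSliceSec: 25,
//   })
//   subscribe(({ isCapturing, data, processTime, recordingTime }) => {
//     console.log(
//       `${isCapturing ? 'Capturing' : 'Finished'}: ${data?.result ?? ''} ` +
//         `(process ${processTime}ms / recorded ${recordingTime}ms)`,
//     )
//   })
//   // later, e.g. in a stop-button handler:
//   await stop()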
async bench(maxThreads: number): Promise<BenchResult> {
const result = await RNWhisper.bench(this.id, maxThreads)
const [config, nThreads, encodeMs, decodeMs, batchMs, promptMs] = JSON.parse(result)
return { config, nThreads, encodeMs, decodeMs, batchMs, promptMs } as BenchResult
}
async release(): Promise<void> {
return RNWhisper.releaseContext(this.id)
}
}
export type ContextOptions = {
filePath: string | number
/**
* Core ML model assets. Required for enabling Core ML when `filePath` is
* loaded via `require`; bundle weights/weight.bin, model.mil, and
* coremldata.bin into the app via `require`.
*/
coreMLModelAsset?: {
filename: string
assets: string[] | number[]
}
/** Treat a string `filePath` as a bundle asset */
isBundleAsset?: boolean
/** Prefer the Core ML model if it exists. If false, the Core ML model is not used even when present. */
useCoreMLIos?: boolean
/** Use GPU if available. Currently iOS only; when enabled, the Core ML option is ignored. */
useGpu?: boolean
/** Use Flash Attention; only recommended when a GPU is available */
useFlashAttn?: boolean
}
const coreMLModelAssetPaths = [
'analytics/coremldata.bin',
'weights/weight.bin',
'model.mil',
'coremldata.bin',
]
export async function initWhisper({
filePath,
coreMLModelAsset,
isBundleAsset,
useGpu = true,
useCoreMLIos = true,
useFlashAttn = false,
}: ContextOptions): Promise<WhisperContext> {
let path = ''
let coreMLAssets: CoreMLAsset[] | undefined
if (coreMLModelAsset) {
const { filename, assets } = coreMLModelAsset
if (filename && assets) {
coreMLAssets = assets
?.map((asset) => {
if (typeof asset === 'number') {
const { uri } = Image.resolveAssetSource(asset)
const filepath = coreMLModelAssetPaths.find((p) => uri.includes(p))
if (filepath) {
return {
uri,
filepath: `${filename}/${filepath}`,
}
}
} else if (typeof asset === 'string') {
return {
uri: asset,
filepath: `${filename}/${asset}`,
}
}
return undefined
})
.filter((asset): asset is CoreMLAsset => asset !== undefined)
}
}
if (typeof filePath === 'number') {
try {
const source = Image.resolveAssetSource(filePath)
if (source) {
path = source.uri
}
} catch (e) {
throw new Error(`Invalid asset: ${filePath}`)
}
} else {
if (!isBundleAsset && filePath.startsWith('http'))
throw new Error(
'Transcribing remote files is not supported, please download the file first',
)
path = filePath
}
if (path.startsWith('file://')) path = path.slice(7)
const { contextId, gpu, reasonNoGPU } = await RNWhisper.initContext({
filePath: path,
isBundleAsset: !!isBundleAsset,
useFlashAttn,
useGpu,
useCoreMLIos,
// Only development mode needs to download Core ML model assets (from the packager server)
downloadCoreMLAssets: __DEV__ && !!coreMLAssets,
coreMLAssets,
})
return new WhisperContext({ contextId, gpu, reasonNoGPU })
}
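// Lifecycle sketch (illustrative asset paths): create a context from a
// bundled ggml model, optionally with Core ML encoder assets on iOS, run a
// quick benchmark, then release the context. Bundling .bin/.mil files via
// `require` assumes they are registered as Metro asset extensions.
//
//   const whisperContext = await initWhisper({
//     filePath: require('../assets/ggml-tiny.en.bin'),
//     coreMLModelAsset:
//       Platform.OS === 'ios'
//         ? {
//             filename: 'ggml-tiny.en-encoder.mlmodelc',
//             assets: [
//               require('../assets/ggml-tiny.en-encoder.mlmodelc/weights/weight.bin'),
//               require('../assets/ggml-tiny.en-encoder.mlmodelc/model.mil'),
//               require('../assets/ggml-tiny.en-encoder.mlmodelc/coremldata.bin'),
//             ],
//           }
//         : undefined,
//   })
//   console.log(`GPU: ${whisperContext.gpu} (${whisperContext.reasonNoGPU})`)
//   const bench = await whisperContext.bench(4)
//   console.log(`encode ${bench.encodeMs}ms / decode ${bench.decodeMs}ms`)
//   await whisperContext.release()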
export async function releaseAllWhisper(): Promise<void> {
return RNWhisper.releaseAllContexts()
}
/** Current version of whisper.cpp */
export const libVersion: string = version
const { useCoreML, coreMLAllowFallback } = RNWhisper.getConstants?.() || {}
/** Whether Core ML models are used on iOS */
export const isUseCoreML: boolean = !!useCoreML
/** Whether CPU fallback is allowed if loading the Core ML model fails */
export const isCoreMLAllowFallback: boolean = !!coreMLAllowFallback
export { AudioSessionIos }