whisper.rn
Version:
React Native binding of whisper.cpp
283 lines (244 loc) • 7.54 kB
text/typescript
import { base64ToUint8Array } from './common'
/**
 * Minimal file-system adapter shape for working with WAV files.
 *
 * `WavFileReader`'s constructor accepts an object carrying the `exists` and
 * `readFile` members declared here (it spells the same signatures inline).
 */
export interface WavFileReaderFs {
// Resolves with the file's content as a string in the requested encoding.
// WavFileReader calls this with 'base64' and decodes the result to bytes.
readFile: (filePath: string, encoding: string) => Promise<string>
// Resolves with whether a file is present at `filePath`.
exists: (filePath: string) => Promise<boolean>
// Presumably deletes the file at `filePath`; it is not called by the
// WavFileReader class in this file, so confirm against the implementer.
unlink: (filePath: string) => Promise<void>
}
/**
 * Format information extracted from a WAV file's `fmt ` and `data` chunks.
 */
export interface WavFileHeader {
// Sample rate read from the fmt chunk (reported as Hz when logged).
sampleRate: number
// Channel count read from the fmt chunk.
channels: number
// Bits-per-sample field read from the fmt chunk.
bitsPerSample: number
// Byte length of the data chunk payload. When the chunk declares more bytes
// than the file actually contains, this is clamped to the bytes present.
dataSize: number
// Length of the audio in seconds:
// dataSize / (sampleRate * channels * (bitsPerSample / 8)).
duration: number
}
/**
 * Result of parsing a WAV byte buffer: the decoded header plus the location
 * of the audio payload inside that buffer.
 */
type ParsedWavFileHeader = {
// Format fields gathered from the fmt and data chunks.
header: WavFileHeader
// Byte offset, from the start of the file, where the data chunk payload begins.
dataOffset: number
}
/**
 * Loads a PCM WAV file through an injected file-system adapter and exposes
 * its parsed header and raw sample bytes.
 *
 * Call `initialize()` first; until it has completed successfully the
 * accessors return `null` / `0`.
 */
export class WavFileReader {
  private readonly filePath: string

  private header: WavFileHeader | null = null

  private audioData: Uint8Array | null = null

  private readonly fs: {
    exists: (filePath: string) => Promise<boolean>
    readFile: (filePath: string, encoding: string) => Promise<string>
  }

  /**
   * @param fs - Adapter used to check for and read the file.
   * @param filePath - Path of the WAV file to load.
   */
  constructor(
    fs: {
      exists: (filePath: string) => Promise<boolean>
      readFile: (filePath: string, encoding: string) => Promise<string>
    },
    filePath: string,
  ) {
    this.fs = fs
    this.filePath = filePath
  }

  /**
   * Read the whole file, parse its RIFF/WAVE chunks and keep the bytes of the
   * data chunk in memory.
   *
   * @throws Error prefixed with "Failed to initialize WAV file reader:" when
   * the file is missing, unreadable, malformed or not PCM.
   */
  async initialize(): Promise<void> {
    try {
      // Check if file exists
      const exists = await this.fs.exists(this.filePath)
      if (!exists) {
        throw new Error(`WAV file not found: ${this.filePath}`)
      }

      // Read the entire file (base64 text) and decode it to raw bytes
      const fileContent = await this.fs.readFile(this.filePath, 'base64')
      const fileData = base64ToUint8Array(fileContent)

      // Parse WAV chunks and extract audio from the actual data chunk.
      const parsedHeader = WavFileReader.parseWavHeader(fileData)
      this.header = parsedHeader.header
      this.audioData = fileData.slice(
        parsedHeader.dataOffset,
        parsedHeader.dataOffset + this.header.dataSize,
      )

      console.log(
        `WAV file loaded: ${this.header.duration.toFixed(2)}s, ${
          this.header.sampleRate
        }Hz, ${this.header.channels}ch`,
      )
    } catch (error) {
      throw new Error(`Failed to initialize WAV file reader: ${error}`)
    }
  }

  /**
   * Parse the RIFF/WAVE container and locate the `fmt ` and `data` chunks.
   *
   * Chunks are walked from byte 12. A `data` chunk that claims more bytes
   * than the file holds is clamped to the bytes present; any other chunk
   * overrunning the file is rejected. Scanning stops at the first `data`
   * chunk found after `fmt `.
   *
   * @param data - Complete file content.
   * @returns The decoded header and the offset of the audio payload.
   * @throws Error when the container is malformed, a required chunk is
   * missing, the format is not PCM, or the channel count, sample rate or
   * bits-per-sample is zero.
   */
  private static parseWavHeader(data: Uint8Array): ParsedWavFileHeader {
    const view = new DataView(data.buffer, data.byteOffset, data.byteLength)

    // Verify RIFF header
    const riffHeader = WavFileReader.readChunkId(data, 0)
    if (riffHeader !== 'RIFF') {
      throw new Error('Invalid WAV file: Missing RIFF header')
    }

    // Verify WAVE format
    const waveHeader = WavFileReader.readChunkId(data, 8)
    if (waveHeader !== 'WAVE') {
      throw new Error('Invalid WAV file: Missing WAVE header')
    }

    let channels = 0
    let sampleRate = 0
    let bitsPerSample = 0
    let isPcm = false
    let hasFmtChunk = false
    let dataOffset = 0
    let dataSize = 0

    let offset = 12
    // The loop guard guarantees the 8-byte chunk header (id + size) is in bounds.
    while (offset + 8 <= data.length) {
      const chunkId = WavFileReader.readChunkId(data, offset)
      const chunkSize = view.getUint32(offset + 4, true)
      const chunkDataOffset = offset + 8

      const availableBytes = data.length - chunkDataOffset
      const chunkExceedsFile = chunkSize > availableBytes
      // Only the data chunk may overrun the file; its size is clamped below.
      if (chunkExceedsFile && chunkId !== 'data') {
        throw new Error('Invalid WAV file: Malformed chunk')
      }
      const effectiveChunkSize = chunkExceedsFile ? availableBytes : chunkSize

      if (chunkId === 'fmt ') {
        if (chunkSize < 16) {
          throw new Error('Invalid WAV file: Malformed fmt chunk')
        }
        const audioFormat = view.getUint16(chunkDataOffset, true)
        channels = view.getUint16(chunkDataOffset + 2, true)
        sampleRate = view.getUint32(chunkDataOffset + 4, true)
        bitsPerSample = view.getUint16(chunkDataOffset + 14, true)
        // Format tag 1 is plain PCM; 0xfffe (extensible) counts as PCM only
        // when its sub-format GUID says so.
        isPcm =
          audioFormat === 1 ||
          (audioFormat === 0xfffe &&
            WavFileReader.hasPcmExtensibleSubFormat(
              data,
              chunkDataOffset,
              chunkSize,
            ))
        hasFmtChunk = true
      } else if (chunkId === 'data') {
        dataOffset = chunkDataOffset
        dataSize = effectiveChunkSize
        if (hasFmtChunk) break
      }

      let nextOffset = chunkDataOffset + effectiveChunkSize
      // Odd-sized chunks are followed by one pad byte.
      if (
        !chunkExceedsFile &&
        chunkSize % 2 !== 0 &&
        nextOffset < data.length
      ) {
        nextOffset += 1
      }
      // Defensive: guarantee forward progress so the scan always terminates.
      if (nextOffset <= offset) {
        throw new Error('Invalid WAV file: Malformed chunk')
      }
      offset = nextOffset
    }

    if (!hasFmtChunk) {
      throw new Error('Invalid WAV file: Missing fmt chunk')
    }
    if (!dataOffset) {
      throw new Error('Invalid WAV file: Missing data chunk')
    }
    if (!isPcm) {
      throw new Error('Unsupported WAV format: Only PCM is supported')
    }
    if (!channels) {
      throw new Error('Invalid WAV file: Invalid channel count')
    }
    if (!sampleRate) {
      throw new Error('Invalid WAV file: Invalid sample rate')
    }
    // A zero sample width would make the byte rate 0 and every duration /
    // time conversion Infinity or NaN, so reject it like the fields above.
    if (!bitsPerSample) {
      throw new Error('Invalid WAV file: Invalid bits per sample')
    }

    const duration =
      dataSize /
      WavFileReader.computeBytesPerSecond(sampleRate, channels, bitsPerSample)

    return {
      header: {
        sampleRate,
        channels,
        bitsPerSample,
        dataSize,
        duration,
      },
      dataOffset,
    }
  }

  /**
   * Number of audio bytes that make up one second for the given format.
   */
  private static computeBytesPerSecond(
    sampleRate: number,
    channels: number,
    bitsPerSample: number,
  ): number {
    return sampleRate * channels * (bitsPerSample / 8)
  }

  /**
   * Read a four-character chunk identifier.
   *
   * @returns The identifier, or an empty string when fewer than four bytes
   * remain at `offset`.
   */
  private static readChunkId(data: Uint8Array, offset: number): string {
    if (offset + 4 > data.length) return ''
    return String.fromCharCode(
      data[offset] ?? 0,
      data[offset + 1] ?? 0,
      data[offset + 2] ?? 0,
      data[offset + 3] ?? 0,
    )
  }

  /**
   * Check whether an extensible fmt chunk carries the PCM sub-format GUID.
   *
   * @param data - Complete file content.
   * @param fmtDataOffset - Offset of the fmt chunk payload.
   * @param chunkSize - Declared size of the fmt chunk.
   * @returns `true` when the 16 bytes at payload offset 24 equal the PCM GUID;
   * `false` when they differ or the chunk is too short to contain them.
   */
  private static hasPcmExtensibleSubFormat(
    data: Uint8Array,
    fmtDataOffset: number,
    chunkSize: number,
  ): boolean {
    // 00000001-0000-0010-8000-00aa00389b71 in its on-disk byte order.
    const pcmSubFormatGuid = [
      0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x80, 0x00, 0x00, 0xaa,
      0x00, 0x38, 0x9b, 0x71,
    ]
    const subFormatOffset = fmtDataOffset + 24
    if (
      chunkSize < 40 ||
      subFormatOffset + pcmSubFormatGuid.length > data.length
    ) {
      return false
    }
    return pcmSubFormatGuid.every(
      (value, index) => data[subFormatOffset + index] === value,
    )
  }

  /**
   * Get a copy of part of the audio payload.
   *
   * The requested window is clipped to the available data.
   *
   * @param startByte - Offset of the first byte, relative to the payload.
   * @param lengthBytes - Number of bytes requested.
   * @returns The copied bytes, or `null` when the reader is not initialized
   * or the clipped window is empty.
   */
  getAudioSlice(startByte: number, lengthBytes: number): Uint8Array | null {
    if (!this.audioData) {
      return null
    }

    const start = Math.max(0, startByte)
    const end = Math.min(this.audioData.length, startByte + lengthBytes)
    if (start >= end) {
      return null
    }

    return this.audioData.slice(start, end)
  }

  /**
   * Get the whole audio payload.
   *
   * @returns The reader's own buffer (not a copy), or `null` before
   * initialization.
   */
  getAudioData(): Uint8Array | null {
    return this.audioData
  }

  /**
   * Get WAV file header information, or `null` before initialization.
   */
  getHeader(): WavFileHeader | null {
    return this.header
  }

  /**
   * Get total audio data size in bytes (0 before initialization).
   */
  getTotalDataSize(): number {
    return this.header?.dataSize ?? 0
  }

  /**
   * Convert byte position to time in seconds (0 before initialization).
   */
  byteToTime(bytePosition: number): number {
    if (!this.header) return 0
    const { sampleRate, channels, bitsPerSample } = this.header
    return (
      bytePosition /
      WavFileReader.computeBytesPerSecond(sampleRate, channels, bitsPerSample)
    )
  }

  /**
   * Convert time in seconds to byte position (0 before initialization).
   *
   * The result is floored to a whole byte; it is not rounded to a
   * sample-frame boundary.
   */
  timeToByte(timeSeconds: number): number {
    if (!this.header) return 0
    const { sampleRate, channels, bitsPerSample } = this.header
    return Math.floor(
      timeSeconds *
        WavFileReader.computeBytesPerSecond(
          sampleRate,
          channels,
          bitsPerSample,
        ),
    )
  }

  /**
   * Get file statistics: the path, the parsed header (if any), the payload
   * size and whether initialization has produced a header.
   */
  getStatistics() {
    return {
      filePath: this.filePath,
      header: this.header,
      totalDataSize: this.getTotalDataSize(),
      isInitialized: !!this.header,
    }
  }
}