@nsfw-filter/nsfwjs
Version:
Detect NSFW content client-side.
80 lines (67 loc) • 2.35 kB
text/typescript
import { load, predictionType, NSFWJS } from '../src/index'
const fs = require('fs');
// Jest runs these tests under Node, where no global `fetch` is assumed to be
// available, so install node-fetch on the global object before anything runs.
// NOTE(review): presumably `load()` uses `fetch` to download the model — confirm in src/index.
const globalAny: any = global
Object.assign(globalAny, { fetch: require('node-fetch') })

/** Timeout in milliseconds passed to the individual test cases below. */
const timeoutMS = 10000

/** Location of the animated GIF fixture that every test case classifies. */
const path = [__dirname, '..', 'example', 'manual-testing', 'data', 'animations', 'smile.gif'].join('/')
/**
 * Normalises a prediction for stable comparison against hard-coded fixtures.
 *
 * NOTE: despite its name this truncates (floors) the probability to four
 * decimal places rather than rounding it; the expected values in the tests
 * presumably rely on this exact behaviour, so do not switch to `Math.round`.
 */
const roundPredicitonProbability = (prediction: predictionType) => {
  const scale = 10000
  const truncated = Math.floor(prediction.probability * scale) / scale
  return { className: prediction.className, probability: truncated }
}
/**
 * End-to-end tests for `NSFWJS.classifyGif` on the `smile.gif` fixture.
 * The model and the GIF bytes are loaded once and shared by all cases.
 */
describe('NSFWJS classify GIF', () => {
  let model: NSFWJS
  let buffer: Buffer

  // `load()` may download the model (see the fetch polyfill above), which can
  // exceed Jest's default 5 s hook timeout, so give the hook the suite timeout.
  beforeAll(async () => {
    model = await load()
    buffer = fs.readFileSync(path)
  }, timeoutMS)

  it("Probabilities match", async () => {
    // Expected top-3 predictions for each sampled frame, with probabilities
    // truncated to 4 decimal places by roundPredicitonProbability.
    const expectedResults = [
      [
        { className: 'Neutral', probability: 0.8766 },
        { className: 'Porn', probability: 0.091 },
        { className: 'Sexy', probability: 0.0316 }
      ],
      [
        { className: 'Neutral', probability: 0.8995 },
        { className: 'Porn', probability: 0.0511 },
        { className: 'Sexy', probability: 0.0487 }
      ],
      [
        { className: 'Neutral', probability: 0.8541 },
        { className: 'Sexy', probability: 0.1027 },
        { className: 'Porn', probability: 0.0424 }
      ]
    ]
    const predictions = await model.classifyGif(buffer, { topk: 3, fps: 0.4 })

    expect(predictions).toHaveLength(expectedResults.length)
    // Compare all frames in one assertion. Iterating over the *actual* results
    // (as an element-by-element loop would) silently passes when a frame has
    // fewer predictions than expected; a full nested equality does not.
    const rounded = predictions.map((frame) => frame.map(roundPredicitonProbability))
    expect(rounded).toEqual(expectedResults)
  },
  timeoutMS
  )

  it("0 fps - single frame from the middle", async () => {
    // With fps = 0 only a single frame is expected to be classified.
    const predictions = await model.classifyGif(buffer, { topk: 3, fps: 0 })
    expect(predictions.length).toBe(1)
  },
  timeoutMS
  )

  // Takes too long. Without `fps`, every frame is classified (per the test
  // name); this fixture is expected to yield 190 prediction sets.
  it.skip("All frames", async () => {
    const predictions = await model.classifyGif(buffer, { topk: 3 })
    expect(predictions.length).toBe(190)
  },
  timeoutMS
  )
})