UNPKG

@remove-background-ai/rembg.js

Version:

A simple wrapper for the https://www.rembg.com API

548 lines (450 loc) 16 kB
import axios from "axios";
import MockAdapter from 'axios-mock-adapter';
import * as tmp from 'tmp-promise';

const FormData = require('form-data');
const { rembg } = require('./../index');

const mockAxios = new MockAdapter(axios);

// tmp.file always yields the same fake temp file so output paths are predictable.
jest.mock('tmp-promise', () => ({
  file: jest.fn().mockResolvedValue({ path: 'path/to/output.png', cleanup: jest.fn() }),
}));

// No real file-system access: reads return a placeholder, writes are no-ops.
jest.mock('fs', () => ({
  createReadStream: jest.fn().mockReturnValue("stream"),
  writeFileSync: jest.fn(),
}));

/**
 * Unit tests for `rembg`: input validation, request shape (headers, form
 * fields, responseType), output modes (base64 vs. temp file), and error
 * propagation. Every spy created in a test is restored at the end of that
 * test so no stubbed implementation leaks into the next one.
 */
describe('rembg', () => {
  /**
   * Stubs the next `axios.request` call to resolve with `data` as a Buffer.
   * @param {string} [data='image data'] - payload the fake server returns.
   * @returns {jest.SpyInstance} the spy; callers must `mockRestore()` it.
   */
  const mockRequestSuccess = (data = 'image data') =>
    jest.spyOn(axios, 'request').mockResolvedValueOnce({ data: Buffer.from(data) });

  beforeEach(() => {
    mockAxios.reset();
    jest.clearAllMocks();
  });

  it('should throw an error if apiKey is not provided', async () => {
    await expect(rembg({
      apiKey: '',
      inputImage: 'path/to/image.png',
    })).rejects.toThrowError('⚠️⚠️⚠️ WARNING ⚠️⚠️⚠️: API key not provided, trials will be very limited.');
  });

  it('should return base64 image if returnBase64 is true', async () => {
    const axiosMock = mockRequestSuccess();

    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options: { returnBase64: true },
    });

    expect(result).toEqual({ base64Image: '' });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    axiosMock.mockRestore();
  });

  it('should return output image path and cleanup function if returnBase64 is false', async () => {
    const axiosMock = mockRequestSuccess();

    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options: { returnBase64: false, format: "png" },
    });

    expect(result).toEqual({ outputImagePath: 'path/to/output.png', cleanup: expect.any(Function) });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    expect(tmp.file).toHaveBeenCalledWith(expect.objectContaining({ prefix: 'rembg-', postfix: '.png' }));
    axiosMock.mockRestore();
  });

  it('should return the right format=WEBP', async () => {
    const axiosMock = mockRequestSuccess();

    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      // No format selected => defaults to WEBP.
      options: { returnBase64: false },
    });

    expect(result).toEqual({ outputImagePath: 'path/to/output.png', cleanup: expect.any(Function) });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    expect(tmp.file).toHaveBeenCalledWith(expect.objectContaining({ prefix: 'rembg-', postfix: '.WEBP' }));
    axiosMock.mockRestore();
  });

  it('should return the right format=PNG', async () => {
    const axiosMock = mockRequestSuccess();

    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options: { returnBase64: false, format: 'PNG' },
    });

    expect(result).toEqual({ outputImagePath: 'path/to/output.png', cleanup: expect.any(Function) });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    expect(tmp.file).toHaveBeenCalledWith(expect.objectContaining({ prefix: 'rembg-', postfix: '.PNG' }));
    axiosMock.mockRestore();
  });

  it('should throw an error if the request fails', async () => {
    const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({ message: 'Request failed' });

    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrowError('❌ Request failed');

    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    axiosMock.mockRestore();
  });

  it('should throw an error if the server responds with an error status code', async () => {
    const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({
      response: { status: 500, data: 'Internal Server Error' },
    });

    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrowError('❌ 500 Internal Server Error');

    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    axiosMock.mockRestore();
  });

  it('should throw an error if no response is received', async () => {
    const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({ request: {} });

    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrowError();

    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ responseType: 'arraybuffer' }));
    axiosMock.mockRestore();
  });

  it('should send a request with the API key header', async () => {
    const axiosMock = mockRequestSuccess();

    await rembg({ apiKey: 'your-api-key', inputImage: 'path/to/image.png' });

    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      headers: expect.objectContaining({ 'x-api-key': 'your-api-key' }),
    }));
    axiosMock.mockRestore();
  });

  it('should send a request with the mask field in the form data', async () => {
    const appendSpy = jest.spyOn(FormData.prototype, 'append');
    const axiosMock = mockRequestSuccess();

    await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      onUploadProgress: () => {},
      onDownloadProgress: () => {},
      options: { returnMask: true, returnBase64: false },
    });

    expect(appendSpy).toHaveBeenCalledWith('image', expect.anything());
    expect(appendSpy).toHaveBeenCalledWith('mask', 'true');
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ data: expect.any(FormData) }));

    // Restore only this test's spies; a global restoreAllMocks() can also
    // wipe the module-level jest.fn implementations (e.g. tmp.file) in some
    // Jest versions, breaking every test that follows.
    axiosMock.mockRestore();
    appendSpy.mockRestore();
  });

  it('should handle buffer input', async () => {
    const axiosMock = mockRequestSuccess('processed image data');
    const inputBuffer = Buffer.from('input image data');
    // Turn FormData.append into a recording no-op.
    const mockAppend = jest.fn();
    const appendSpy = jest.spyOn(FormData.prototype, 'append').mockImplementation(mockAppend);

    const result = await rembg({ apiKey: 'your-api-key', inputImage: inputBuffer });

    expect(result).toEqual({ outputImagePath: 'path/to/output.png', cleanup: expect.any(Function) });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ data: expect.any(FormData) }));
    expect(mockAppend).toHaveBeenCalledWith('image', inputBuffer, { filename: 'image.png' });

    axiosMock.mockRestore();
    // Without this, the no-op append leaks into later tests.
    appendSpy.mockRestore();
  });

  it('should handle base64 input', async () => {
    const axiosMock = mockRequestSuccess('processed image data');
    // The spy must be installed BEFORE rembg runs, otherwise it records no calls.
    const formDataAppendSpy = jest.spyOn(FormData.prototype, 'append');
    const base64Input = { base64: 'aW5wdXQgaW1hZ2UgZGF0YQ==' }; // "input image data" in base64

    const result = await rembg({ apiKey: 'your-api-key', inputImage: base64Input });

    expect(result).toEqual({ outputImagePath: 'path/to/output.png', cleanup: expect.any(Function) });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({ data: expect.any(FormData) }));
    // The base64 payload should be decoded into a Buffer before upload.
    expect(formDataAppendSpy).toHaveBeenCalledWith(
      'image',
      expect.any(Buffer),
      { filename: 'image.png', contentType: "image/png" },
    );

    axiosMock.mockRestore();
    formDataAppendSpy.mockRestore();
  });

  it('should throw an error for invalid input type', async () => {
    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 123, // Invalid input type
    })).rejects.toThrowError('Invalid input type. Must be a file path, Buffer, or an object with a base64 property.');
  });

  // Form fields for angle, expand and bg_color.
  // The outer beforeEach already clears mock state before each of these tests.
  describe('New parameters (angle, expand, bg_color)', () => {
    /**
     * Runs rembg against the default test image with the given options.
     * @param {object} options - rembg `options` object under test.
     */
    const runWithOptions = (options) => rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      onUploadProgress: () => {},
      onDownloadProgress: () => {},
      options,
    });

    it('should send angle parameter when provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ angle: 90 });

      expect(appendSpy).toHaveBeenCalledWith('angle', '90');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should NOT send angle parameter when not provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({});

      expect(appendSpy).not.toHaveBeenCalledWith('angle', expect.anything());
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send expand parameter when explicitly provided as true', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ expand: true });

      expect(appendSpy).toHaveBeenCalledWith('expand', 'true');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send expand parameter when explicitly provided as false', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ expand: false });

      expect(appendSpy).toHaveBeenCalledWith('expand', 'false');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send expand parameter with default value "true" when not provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({});

      expect(appendSpy).toHaveBeenCalledWith('expand', 'true');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send bg_color parameter when provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ bg_color: 'blue' });

      expect(appendSpy).toHaveBeenCalledWith('bg_color', 'blue');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send bg_color parameter with hex color when provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ bg_color: '#FF5733' });

      expect(appendSpy).toHaveBeenCalledWith('bg_color', '#FF5733');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should NOT send bg_color parameter when not provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({});

      expect(appendSpy).not.toHaveBeenCalledWith('bg_color', expect.anything());
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });

    it('should send all three new parameters when all are provided', async () => {
      const appendSpy = jest.spyOn(FormData.prototype, 'append');
      const axiosMock = mockRequestSuccess();

      await runWithOptions({ angle: 45, expand: false, bg_color: 'red' });

      expect(appendSpy).toHaveBeenCalledWith('angle', '45');
      expect(appendSpy).toHaveBeenCalledWith('expand', 'false');
      expect(appendSpy).toHaveBeenCalledWith('bg_color', 'red');
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    });
  });
});