/*
 * @remove-background-ai/rembg.js
 * Version:
 * A simple wrapper for the https://www.rembg.com API
 * 548 lines (450 loc) • 16 kB
 * text/typescript
 */
import axios from "axios";
import MockAdapter from 'axios-mock-adapter';
import * as tmp from 'tmp-promise';
const FormData = require('form-data');
const { rembg } = require('./../index');
// Mock adapter installed on the shared axios instance and reset in beforeEach.
// Tests stub axios.request directly with jest.spyOn, so this adapter only sees
// requests that are not stubbed by a spy.
// NOTE(review): no handlers are ever registered on it — confirm it is meant purely as a network safety net.
const mockAxios = new MockAdapter(axios);
// Module mocks. jest.mock() calls are hoisted above the imports by the Jest
// transform, so these factories must not reference variables from this file.
// tmp.file always resolves to the same fixed path, so every test can assert
// outputImagePath === 'path/to/output.png' regardless of the requested format.
jest.mock('tmp-promise', () => ({
file: jest.fn().mockResolvedValue({ path: 'path/to/output.png', cleanup: jest.fn() })
}));
// fs is replaced by just these two stubs: createReadStream returns the string
// "stream" instead of a real stream, and writeFileSync records calls without writing.
jest.mock('fs', () => ({
createReadStream: jest.fn().mockReturnValue("stream"),
writeFileSync: jest.fn()
}));
describe('rembg', () => {
beforeEach(() => {
  // Give every test a clean slate: forget calls recorded on all jest mocks,
  // then clear handlers and history on the axios mock adapter.
  jest.clearAllMocks();
  mockAxios.reset();
});
it('should throw an error if apiKey is not provided', async () => {
  // An empty apiKey string counts as "not provided".
  // `toThrow` replaces the `toThrowError` alias, which Jest 30 removed.
  await expect(
    rembg({
      apiKey: '',
      inputImage: 'path/to/image.png',
    }),
  ).rejects.toThrow('⚠️⚠️⚠️ WARNING ⚠️⚠️⚠️: API key not provided, trials will be very limited.');
});
it('should return base64 image if returnBase64 is true', async () => {
  // Stub the HTTP call; restored in `finally` so the spy never leaks into later tests.
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('image data'),
  });
  try {
    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options: {
        returnBase64: true,
      },
    });
    // 'aW1hZ2UgZGF0YQ==' is the base64 encoding of the stubbed body 'image data'.
    expect(result).toEqual({
      base64Image: 'data:image/png;base64,aW1hZ2UgZGF0YQ==',
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
// Each row: test title, options passed to rembg, expected temp-file postfix.
// The WEBP row omits `format` to exercise the default.
it.each([
  [
    'should return output image path and cleanup function if returnBase64 is false',
    { returnBase64: false, format: 'png' },
    '.png',
  ],
  [
    'should return the right format=WEBP',
    { returnBase64: false }, // no format selected => default to WEBP
    '.WEBP',
  ],
  [
    'should return the right format=PNG',
    { returnBase64: false, format: 'PNG' },
    '.PNG',
  ],
])('%s', async (_title, options, expectedPostfix) => {
  // Stub the HTTP call; restored in `finally` even if an assertion fails.
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('image data'),
  });
  try {
    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options,
    });
    // tmp.file is mocked to always resolve to this path, whatever the format.
    expect(result).toEqual({
      outputImagePath: 'path/to/output.png',
      cleanup: expect.any(Function),
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
    // The format string, with its case preserved, becomes the temp-file extension.
    expect(tmp.file).toHaveBeenCalledWith(expect.objectContaining({
      prefix: 'rembg-',
      postfix: expectedPostfix,
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
it('should throw an error if the request fails', async () => {
  // Reject with an axios-like error that carries only a message.
  const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({
    message: 'Request failed',
  });
  try {
    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrow('❌ Request failed');
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
it('should throw an error if the server responds with an error status code', async () => {
  // Reject with an axios-like error carrying an HTTP response (status + body).
  const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({
    response: {
      status: 500,
      data: 'Internal Server Error',
    },
  });
  try {
    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrow('❌ 500 Internal Server Error');
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
it('should throw an error if no response is received', async () => {
  // Reject with an axios-like error that has a request but no response.
  const axiosMock = jest.spyOn(axios, 'request').mockRejectedValueOnce({
    request: {},
  });
  try {
    // NOTE(review): only asserts that *some* error is thrown — consider pinning the message.
    await expect(rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    })).rejects.toThrow();
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
// The provided API key must be sent in the "x-api-key" request header.
it('should send a request with the API key header', async () => {
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('image data'),
  });
  try {
    await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      headers: expect.objectContaining({
        'x-api-key': 'your-api-key',
      }),
    }));
  } finally {
    // Restore even on failure so the stub does not leak into later tests.
    axiosMock.mockRestore();
  }
});
it('should send a request with the mask field in the form data', async () => {
  // Spy on FormData.prototype.append; the spy calls through to the real method.
  const appendSpy = jest.spyOn(FormData.prototype, 'append');
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('image data'),
  });
  try {
    await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      onUploadProgress: () => {},
      onDownloadProgress: () => {},
      options: {
        returnMask: true,
        returnBase64: false,
      },
    });
    // returnMask: true is expected to be sent as the string field mask='true'.
    expect(appendSpy).toHaveBeenCalledWith('image', expect.anything());
    expect(appendSpy).toHaveBeenCalledWith('mask', 'true');
    // The request body is a FormData instance.
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      data: expect.any(FormData),
    }));
  } finally {
    // Restore every spy even when an assertion above fails.
    jest.restoreAllMocks();
  }
});
it('should handle buffer input', async () => {
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('processed image data'),
  });
  // Replace append with a recording no-op. It is restored in `finally` so the
  // no-op does not leak into later tests.
  const mockAppend = jest.fn();
  const appendSpy = jest.spyOn(FormData.prototype, 'append').mockImplementation(mockAppend);
  const inputBuffer = Buffer.from('input image data');
  try {
    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: inputBuffer,
    });
    expect(result).toEqual({
      outputImagePath: 'path/to/output.png',
      cleanup: expect.any(Function),
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      data: expect.any(FormData),
    }));
    // The Buffer itself is appended as the 'image' field.
    expect(mockAppend).toHaveBeenCalledWith('image', inputBuffer, { filename: 'image.png' });
  } finally {
    appendSpy.mockRestore();
    axiosMock.mockRestore();
  }
});
// Base64 input: an object with a `base64` string property.
it('should handle base64 input', async () => {
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('processed image data'),
  });
  // The spy must exist BEFORE rembg runs; a spy created afterwards records no calls.
  const formDataAppendSpy = jest.spyOn(FormData.prototype, 'append');
  const base64Input = { base64: 'aW5wdXQgaW1hZ2UgZGF0YQ==' }; // "input image data" in base64
  try {
    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: base64Input,
    });
    expect(result).toEqual({
      outputImagePath: 'path/to/output.png',
      cleanup: expect.any(Function),
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      data: expect.any(FormData),
    }));
    // The base64 payload is decoded to a Buffer before being appended.
    // NOTE(review): the Buffer-input test expects no contentType here — confirm which append options are intended.
    expect(formDataAppendSpy).toHaveBeenCalledWith(
      'image',
      expect.any(Buffer),
      { filename: 'image.png', contentType: "image/png"},
    );
  } finally {
    axiosMock.mockRestore();
    formDataAppendSpy.mockRestore();
  }
});
// Test for invalid input type: a number is neither a path, a Buffer, nor a { base64 } object.
it('should throw an error for invalid input type', async () => {
  await expect(rembg({
    apiKey: 'your-api-key',
    inputImage: 123, // Invalid input type
  })).rejects.toThrow('Invalid input type. Must be a file path, Buffer, or an object with a base64 property.');
});
// returnBase64 with a file-path `inputImage` (the option is named `inputImage`, not `input`).
// Retitled so Jest does not report two tests with the same name.
it('should return base64 image if returnBase64 is true (file path inputImage)', async () => {
  const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
    data: Buffer.from('image data'),
  });
  try {
    const result = await rembg({
      apiKey: 'your-api-key',
      inputImage: 'path/to/image.png',
      options: {
        returnBase64: true,
      },
    });
    // 'aW1hZ2UgZGF0YQ==' is the base64 encoding of the stubbed body 'image data'.
    expect(result).toEqual({
      base64Image: 'data:image/png;base64,aW1hZ2UgZGF0YQ==',
    });
    expect(axiosMock).toHaveBeenCalledWith(expect.objectContaining({
      responseType: 'arraybuffer',
    }));
  } finally {
    axiosMock.mockRestore();
  }
});
// Tests for new parameters: angle, expand, bg_color
describe('New parameters (angle, expand, bg_color)', () => {
  /**
   * Calls rembg with a file-path input and the given `options`.
   * While the call runs, axios.request is stubbed and FormData.prototype.append
   * is spied on; the append spy is then passed to `assertAppends`. Both spies are
   * restored in `finally`, so a failing assertion cannot leak them into later tests.
   * Mock call history is already cleared by the enclosing beforeEach.
   * @param {object} options - forwarded as rembg's `options`
   * @param {(appendSpy: jest.SpyInstance) => void} assertAppends - assertions on the append spy
   * @returns {Promise<void>}
   */
  const runWithOptions = async (options, assertAppends) => {
    const appendSpy = jest.spyOn(FormData.prototype, 'append');
    const axiosMock = jest.spyOn(axios, 'request').mockResolvedValueOnce({
      data: Buffer.from('image data'),
    });
    try {
      await rembg({
        apiKey: 'your-api-key',
        inputImage: 'path/to/image.png',
        onUploadProgress: () => {},
        onDownloadProgress: () => {},
        options,
      });
      assertAppends(appendSpy);
    } finally {
      axiosMock.mockRestore();
      appendSpy.mockRestore();
    }
  };

  it('should send angle parameter when provided', async () => {
    await runWithOptions({ angle: 90 }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('angle', '90');
    });
  });

  it('should NOT send angle parameter when not provided', async () => {
    await runWithOptions({}, (appendSpy) => {
      expect(appendSpy).not.toHaveBeenCalledWith('angle', expect.anything());
    });
  });

  it('should send expand parameter when explicitly provided as true', async () => {
    await runWithOptions({ expand: true }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('expand', 'true');
    });
  });

  it('should send expand parameter when explicitly provided as false', async () => {
    await runWithOptions({ expand: false }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('expand', 'false');
    });
  });

  it('should send expand parameter with default value "true" when not provided', async () => {
    await runWithOptions({}, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('expand', 'true');
    });
  });

  it('should send bg_color parameter when provided', async () => {
    await runWithOptions({ bg_color: 'blue' }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('bg_color', 'blue');
    });
  });

  it('should send bg_color parameter with hex color when provided', async () => {
    await runWithOptions({ bg_color: '#FF5733' }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('bg_color', '#FF5733');
    });
  });

  it('should NOT send bg_color parameter when not provided', async () => {
    await runWithOptions({}, (appendSpy) => {
      expect(appendSpy).not.toHaveBeenCalledWith('bg_color', expect.anything());
    });
  });

  it('should send all three new parameters when all are provided', async () => {
    await runWithOptions({ angle: 45, expand: false, bg_color: 'red' }, (appendSpy) => {
      expect(appendSpy).toHaveBeenCalledWith('angle', '45');
      expect(appendSpy).toHaveBeenCalledWith('expand', 'false');
      expect(appendSpy).toHaveBeenCalledWith('bg_color', 'red');
    });
  });
});
});