alwaysai
Version:
The alwaysAI command-line interface (CLI)
112 lines (98 loc) • 3.61 kB
text/typescript
import { join } from 'path';
import { createReadStream, stat } from 'fs';
import {
CLOUD_API_MODEL_VERSION_PACKAGES_PATH,
CloudApiUrl
} from '@alwaysai/cloud-api';
import { CliTerseError } from '@alwaysai/alwayscli';
import * as tar from 'tar';
import * as tempy from 'tempy';
import { promisify } from 'util';
import { ModelJsonFile } from './model-json-file';
import {
JsSpawner,
throwIfNotOk,
copyFiles,
stringifyError,
logger
} from '../../util';
import {
CliRpcClient,
CliAuthenticationClient,
getSystemId
} from '../../infrastructure';
import { ModelId } from './model-id';
import { fetchFilestream } from '../../util/fetch';
import { Readable } from 'stream';
/**
 * Fetches the metadata of one version of a model through the CLI RPC client.
 *
 * Any failure of the lookup (including creating the client) is logged and
 * surfaced to the user as a terse "not found" error.
 *
 * @param id - the model ID passed to the RPC `getModelVersion` call
 * @param version - the model version number
 * @returns whatever the RPC client's `getModelVersion` resolves to
 * @throws CliTerseError when the lookup fails for any reason
 */
async function getModelVersion(id: string, version: number) {
  try {
    const client = CliRpcClient();
    const modelVersion = await client.getModelVersion({ id, version });
    return modelVersion;
  } catch (err) {
    const details = stringifyError(err);
    logger.error(details);
    throw new CliTerseError(`Model ${id} not found!`);
  }
}
// TypeScript type-checks `fetch` against the DOM `RequestInit` type, which
// lacks Node's `duplex` option (needed when the request body is a stream).
// Augment the type locally instead of making a deeper tsconfig change.
interface NodeRequestInit extends RequestInit {
  duplex?: 'half';
}

/**
 * Best-effort removal of a temporary file or directory.
 *
 * Failures are logged rather than thrown so that cleanup running in a
 * `finally` block never replaces the error that triggered it.
 */
async function removeTempPath(target: string) {
  try {
    await JsSpawner().rimraf(target);
  } catch (exception) {
    logger.error(stringifyError(exception));
  }
}

/**
 * Writes a gzipped tarball of the files in `dir` to `packagePath`.
 *
 * The files are first copied into a temporary staging directory named `name`,
 * so the archive's top-level directory matches the model name. The staging
 * directory is removed whether or not packing succeeds.
 */
async function packModel(dir: string, name: string, packagePath: string) {
  const tmpDir = tempy.directory({ prefix: 'alwaysai' });
  try {
    const tmpModelDir = join(tmpDir, name);
    await JsSpawner().mkdirp(tmpModelDir);
    await copyFiles(JsSpawner({ path: dir }), JsSpawner({ path: tmpModelDir }));
    await tar.create({ cwd: tmpDir, gzip: true, file: packagePath }, [name]);
  } finally {
    await removeTempPath(tmpDir);
  }
}

export const modelPackageCloudClient = {
  /**
   * Starts downloading the package of a model version.
   *
   * @param id - the model ID
   * @param version - the model version number
   * @returns the stream resolved by `fetchFilestream` for the package URL
   * @throws CliTerseError if the version cannot be looked up or the download
   *   request fails
   */
  async download(id: string, version: number) {
    const { uuid } = await getModelVersion(id, version);
    try {
      const modelPackageUrl = ModelPackageUrl(uuid);
      const authorizationHeader =
        await CliAuthenticationClient().getAuthorizationHeader();
      const filestream = await fetchFilestream(modelPackageUrl, {
        method: 'GET',
        headers: { ...authorizationHeader, 'Content-Length': '0' }
      });
      return filestream;
    } catch (exception) {
      logger.error(stringifyError(exception));
      throw new CliTerseError(`Failed to download model ${id}!`);
    }
  },

  /**
   * Publishes the model in `dir` to the cloud.
   *
   * Creates a new model version from the directory's model JSON file, uploads
   * a gzipped tarball of the directory (top-level folder named after the
   * model), and then finalizes the version. Temporary files created along the
   * way are removed whether or not the upload succeeds.
   *
   * @param dir - the model directory; defaults to the current working directory
   * @returns the UUID of the created model version
   */
  async publish(dir = process.cwd()) {
    const modelJson = ModelJsonFile(dir).read();
    const rpcClient = CliRpcClient();
    const { uuid } = await rpcClient.createModelVersion(modelJson);
    const authorizationHeader =
      await CliAuthenticationClient().getAuthorizationHeader();
    // Ensure top level directory matches model name
    const { name } = ModelId.parse(modelJson.id);
    const modelPackagePath = tempy.file();
    try {
      await packModel(dir, name, modelPackagePath);
      const stats = await promisify(stat)(modelPackagePath);
      const modelPackageStream = createReadStream(modelPackagePath);
      try {
        // Convert to web stream since Typescript maps to DOM types at compile
        // time, which don't include all supported runtime types.
        const webModelPackageStream = Readable.toWeb(modelPackageStream);
        const response = await fetch(ModelPackageUrl(uuid), {
          method: 'PUT',
          headers: {
            'Content-Type': 'application/gzip',
            'Content-Length': stats.size.toString(),
            ...authorizationHeader
          },
          body: webModelPackageStream,
          duplex: 'half'
        } as NodeRequestInit);
        await throwIfNotOk(response);
      } finally {
        // Release the file handle even if the request failed before the
        // stream was fully consumed, and before the file is deleted below.
        modelPackageStream.destroy();
      }
    } finally {
      await removeTempPath(modelPackagePath);
    }
    await rpcClient.finalizeModelVersion(uuid);
    return uuid;
  }
};
/**
 * Builds the cloud API URL for a model version's package.
 *
 * @param uuid - the model version UUID appended as the final path segment
 * @returns the cloud API base URL for the current system ID, followed by the
 *   model version packages path and `/<uuid>`
 */
function ModelPackageUrl(uuid: string) {
  const packagesUrl = `${CloudApiUrl(getSystemId())}${CLOUD_API_MODEL_VERSION_PACKAGES_PATH}`;
  return `${packagesUrl}/${uuid}`;
}