UNPKG

alwaysai

Version:

The alwaysAI command-line interface (CLI)

112 lines (98 loc) 3.61 kB
import { join } from 'path';
import { createReadStream, promises as fsPromises, stat } from 'fs';
import {
  CLOUD_API_MODEL_VERSION_PACKAGES_PATH,
  CloudApiUrl
} from '@alwaysai/cloud-api';
import { CliTerseError } from '@alwaysai/alwayscli';
import * as tar from 'tar';
import * as tempy from 'tempy';
import { promisify } from 'util';
import { ModelJsonFile } from './model-json-file';
import {
  JsSpawner,
  throwIfNotOk,
  copyFiles,
  stringifyError,
  logger
} from '../../util';
import {
  CliRpcClient,
  CliAuthenticationClient,
  getSystemId
} from '../../infrastructure';
import { ModelId } from './model-id';
import { fetchFilestream } from '../../util/fetch';
import { Readable } from 'stream';

/**
 * TypeScript maps `fetch` options to the DOM `RequestInit` type, which does
 * not include the `duplex` property used when streaming a request body from
 * Node. Augmenting the type here avoids a more in-depth tsconfig change.
 */
interface NodeRequestInit extends RequestInit {
  duplex?: 'half';
}

/**
 * Look up a model version through the CLI RPC client.
 *
 * @param id - Model id.
 * @param version - Model version number.
 * @returns The RPC result (callers read its `uuid`).
 * @throws CliTerseError if the RPC call fails for any reason; the underlying
 *   error is logged first.
 */
async function getModelVersion(id: string, version: number) {
  try {
    return await CliRpcClient().getModelVersion({ id, version });
  } catch (exception) {
    logger.error(stringifyError(exception));
    throw new CliTerseError(`Model ${id} not found!`);
  }
}

/**
 * Copy `dir` into a temporary staging directory whose top-level folder is
 * `name`, then write a gzipped tarball of that folder to `packagePath`.
 * The staging directory is removed whether or not copying/archiving succeeds.
 *
 * @param dir - Model source directory.
 * @param name - Top-level directory name inside the archive.
 * @param packagePath - Destination path for the `.tar.gz`.
 */
async function createModelPackage(
  dir: string,
  name: string,
  packagePath: string
) {
  const tmpDir = tempy.directory({ prefix: 'alwaysai' });
  try {
    const tmpModelDir = join(tmpDir, name);
    await JsSpawner().mkdirp(tmpModelDir);
    await copyFiles(JsSpawner({ path: dir }), JsSpawner({ path: tmpModelDir }));
    await tar.create({ cwd: tmpDir, gzip: true, file: packagePath }, [name]);
  } finally {
    await JsSpawner().rimraf(tmpDir);
  }
}

/**
 * Best-effort removal of a temporary file. Errors are logged rather than
 * thrown so cleanup never masks the outcome of the main operation.
 */
async function removeTempFile(path: string) {
  try {
    // `force` makes a missing file (e.g. archiving failed early) a no-op.
    await fsPromises.rm(path, { force: true });
  } catch (exception) {
    logger.error(stringifyError(exception));
  }
}

export const modelPackageCloudClient = {
  /**
   * Start downloading the package of model `id` at `version`.
   *
   * @returns Whatever `fetchFilestream` resolves to for the package URL.
   * @throws CliTerseError if the version lookup or the download request fails.
   */
  async download(id: string, version: number) {
    const { uuid } = await getModelVersion(id, version);
    try {
      const modelPackageUrl = ModelPackageUrl(uuid);
      const authorizationHeader =
        await CliAuthenticationClient().getAuthorizationHeader();
      const filestream = await fetchFilestream(modelPackageUrl, {
        method: 'GET',
        headers: { ...authorizationHeader, 'Content-Length': '0' }
      });
      return filestream;
    } catch (exception) {
      logger.error(stringifyError(exception));
      throw new CliTerseError(`Failed to download model ${id}!`);
    }
  },

  /**
   * Publish the model in `dir` as a new model version.
   *
   * The steps are:
   * 1. Create a model version from the directory's model JSON.
   * 2. Archive the directory (top-level folder renamed to the model name).
   * 3. PUT the archive to the cloud API.
   * 4. Finalize the version.
   *
   * All temporary files are removed even when a step fails.
   * NOTE(review): if packaging or upload fails, the version created in step 1
   * is left unfinalized; confirm whether the backend cleans that up.
   *
   * @param dir - Model directory; defaults to the current working directory.
   * @returns The uuid of the published model version.
   */
  async publish(dir = process.cwd()) {
    const modelJson = ModelJsonFile(dir).read();
    const rpcClient = CliRpcClient();
    const { uuid } = await rpcClient.createModelVersion(modelJson);
    const authorizationHeader =
      await CliAuthenticationClient().getAuthorizationHeader();

    // Ensure top level directory matches model name
    const { name } = ModelId.parse(modelJson.id);
    const modelPackagePath = tempy.file();
    try {
      await createModelPackage(dir, name, modelPackagePath);
      const stats = await promisify(stat)(modelPackagePath);
      const modelPackageStream = createReadStream(modelPackagePath);
      try {
        // Convert to web stream since Typescript maps to DOM types at compile
        // time, which don't include all supported runtime types.
        const webModelPackageStream = Readable.toWeb(modelPackageStream);
        const response = await fetch(ModelPackageUrl(uuid), {
          method: 'PUT',
          headers: {
            'Content-Type': 'application/gzip',
            'Content-Length': stats.size.toString(),
            ...authorizationHeader
          },
          body: webModelPackageStream,
          duplex: 'half'
        } as NodeRequestInit);
        await throwIfNotOk(response);
      } finally {
        // Release the file descriptor even if the request failed before the
        // body was fully consumed; a no-op once the stream has closed.
        modelPackageStream.destroy();
      }
    } finally {
      // Without this, every publish leaves a full-size archive in the OS
      // temp directory.
      await removeTempFile(modelPackagePath);
    }

    await rpcClient.finalizeModelVersion(uuid);
    return uuid;
  }
};

/**
 * Build the cloud API URL for the package of the model version `uuid`,
 * using the API base URL for the current system id.
 */
function ModelPackageUrl(uuid: string) {
  const cloudApiUrl = CloudApiUrl(getSystemId());
  return `${cloudApiUrl}${CLOUD_API_MODEL_VERSION_PACKAGES_PATH}/${uuid}`;
}