/*
 * aws-cdk-lib
 * Version: 2.70.0
 * Version 2 of the AWS Cloud Development Kit library.
 * 2 lines (1 loc) • 8.21 kB
 * JavaScript
 */
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.SageMakerCreateTrainingJob = undefined;

const jsiiDeprecationWarnings = require("../../../.warnings.jsii.js");
const JSII_RTTI_SYMBOL_1 = Symbol.for("jsii.rtti");
const ec2 = require("../../../aws-ec2");
const iam = require("../../../aws-iam");
const sfn = require("../../../aws-stepfunctions");
const core_1 = require("../../../core");
const base_types_1 = require("./base-types");
const utils_1 = require("./private/utils");
const task_utils_1 = require("../private/task-utils");

/**
 * Step Functions task state that invokes SageMaker `createTrainingJob`.
 *
 * Supported integration patterns: REQUEST_RESPONSE (default) and RUN_JOB
 * (see SUPPORTED_INTEGRATION_PATTERNS below).
 */
class SageMakerCreateTrainingJob extends sfn.TaskStateBase {
  /**
   * @param scope - construct scope.
   * @param id - construct id.
   * @param props - task properties. Defaults applied here:
   *   - integrationPattern: REQUEST_RESPONSE
   *   - resourceConfig: 1 instance of M4.XLARGE with a 10 GiB volume
   *   - stoppingCondition: maxRuntime of 1 hour
   *   - algorithmSpecification.trainingInputMode: InputMode.FILE
   *   - inputDataConfig[*].dataSource.s3DataSource.s3DataType: S3DataType.S3_PREFIX
   * @throws Error if neither `algorithmSpecification.algorithmName` nor
   *   `algorithmSpecification.trainingImage` is set, or if the integration
   *   pattern is not supported.
   */
  constructor(scope, id, props) {
    super(scope, id, props);
    this.props = props;
    this.connections = new ec2.Connections();
    this.securityGroups = [];

    // jsii-emitted deprecation check. Unless JSII_DEBUG=1, the stack trace of a
    // DeprecationError is trimmed so it starts at the caller of this constructor.
    try {
      jsiiDeprecationWarnings.aws_cdk_lib_aws_stepfunctions_tasks_SageMakerCreateTrainingJobProps(props);
    } catch (error) {
      if (process.env.JSII_DEBUG !== "1" && error.name === "DeprecationError") {
        Error.captureStackTrace(error, SageMakerCreateTrainingJob);
      }
      throw error;
    }

    this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE;
    task_utils_1.validatePatternSupported(
      this.integrationPattern,
      SageMakerCreateTrainingJob.SUPPORTED_INTEGRATION_PATTERNS,
    );

    this.resourceConfig = props.resourceConfig || {
      instanceCount: 1,
      instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE),
      volumeSize: core_1.Size.gibibytes(10),
    };
    this.stoppingCondition = props.stoppingCondition || { maxRuntime: core_1.Duration.hours(1) };

    const spec = props.algorithmSpecification;
    if (!spec.algorithmName && !spec.trainingImage) {
      throw new Error("Must define either an algorithm name or training image URI in the algorithm specification");
    }

    // Copy rather than mutate caller-owned props when filling in defaults.
    this.algorithmSpecification = spec.trainingInputMode
      ? spec
      : { ...spec, trainingInputMode: base_types_1.InputMode.FILE };

    this.inputDataConfig = props.inputDataConfig.map((config) => {
      if (config.dataSource.s3DataSource.s3DataType) {
        return config;
      }
      return {
        ...config,
        dataSource: {
          s3DataSource: {
            ...config.dataSource.s3DataSource,
            s3DataType: base_types_1.S3DataType.S3_PREFIX,
          },
        },
      };
    });

    if (props.vpcConfig) {
      this.vpc = props.vpcConfig.vpc;
      this.subnets = props.vpcConfig.subnets
        ? this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds
        : this.vpc.selectSubnets().subnetIds;
    }

    // Must run after this.vpc is set: makePolicyStatements creates the
    // security group when a VPC is configured.
    this.taskPolicies = this.makePolicyStatements();
  }

  /** The role SageMaker assumes for the training job. Throws if not yet assigned. */
  get role() {
    if (this._role === undefined) {
      throw new Error("role not available yet--use the object in a Task first");
    }
    return this._role;
  }

  /** Same object as `role`. Throws if not yet assigned. */
  get grantPrincipal() {
    if (this._grantPrincipal === undefined) {
      throw new Error("Principal not available yet--use the object in a Task first");
    }
    return this._grantPrincipal;
  }

  /**
   * Adds a security group to the training job's VPC configuration.
   *
   * @param securityGroup - group whose `securityGroupId` is added to
   *   `VpcConfig.SecurityGroupIds`.
   */
  addSecurityGroup(securityGroup) {
    try {
      jsiiDeprecationWarnings.aws_cdk_lib_aws_ec2_ISecurityGroup(securityGroup);
    } catch (error) {
      if (process.env.JSII_DEBUG !== "1" && error.name === "DeprecationError") {
        Error.captureStackTrace(error, this.addSecurityGroup);
      }
      throw error;
    }
    this.securityGroups.push(securityGroup);
  }

  /** @internal Renders the Task state's `Resource` and `Parameters`. */
  _renderTask() {
    return {
      Resource: task_utils_1.integrationResourceArn("sagemaker", "createTrainingJob", this.integrationPattern),
      Parameters: sfn.FieldUtils.renderObject(this.renderParameters()),
    };
  }

  /** Builds the CreateTrainingJob request body from props and resolved defaults. */
  renderParameters() {
    return {
      TrainingJobName: this.props.trainingJobName,
      EnableNetworkIsolation: this.props.enableNetworkIsolation,
      RoleArn: this._role.roleArn,
      ...this.renderAlgorithmSpecification(this.algorithmSpecification),
      ...this.renderInputDataConfig(this.inputDataConfig),
      ...this.renderOutputDataConfig(this.props.outputDataConfig),
      ...this.renderResourceConfig(this.resourceConfig),
      ...this.renderStoppingCondition(this.stoppingCondition),
      ...this.renderHyperparameters(this.props.hyperparameters),
      ...utils_1.renderTags(this.props.tags),
      ...this.renderVpcConfig(this.props.vpcConfig),
      ...utils_1.renderEnvironment(this.props.environment),
    };
  }

  /** Renders `AlgorithmSpecification`, omitting keys whose source value is falsy. */
  renderAlgorithmSpecification(spec) {
    return {
      AlgorithmSpecification: {
        TrainingInputMode: spec.trainingInputMode,
        ...(spec.trainingImage ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}),
        ...(spec.algorithmName ? { AlgorithmName: spec.algorithmName } : {}),
        ...(spec.metricDefinitions
          ? {
              MetricDefinitions: spec.metricDefinitions.map((metric) => ({
                Name: metric.name,
                Regex: metric.regex,
              })),
            }
          : {}),
      },
    };
  }

  /** Renders `InputDataConfig`, one entry per channel; optional keys are omitted when falsy. */
  renderInputDataConfig(config) {
    return {
      InputDataConfig: config.map((channel) => {
        const s3 = channel.dataSource.s3DataSource;
        return {
          ChannelName: channel.channelName,
          DataSource: {
            S3DataSource: {
              S3Uri: s3.s3Location.bind(this, { forReading: true }).uri,
              S3DataType: s3.s3DataType,
              ...(s3.s3DataDistributionType ? { S3DataDistributionType: s3.s3DataDistributionType } : {}),
              ...(s3.attributeNames ? { AttributeNames: s3.attributeNames } : {}),
            },
          },
          ...(channel.compressionType ? { CompressionType: channel.compressionType } : {}),
          ...(channel.contentType ? { ContentType: channel.contentType } : {}),
          ...(channel.inputMode ? { InputMode: channel.inputMode } : {}),
          ...(channel.recordWrapperType ? { RecordWrapperType: channel.recordWrapperType } : {}),
        };
      }),
    };
  }

  /** Renders `OutputDataConfig`; `KmsKeyId` only when an encryption key is given. */
  renderOutputDataConfig(config) {
    return {
      OutputDataConfig: {
        S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri,
        ...(config.encryptionKey ? { KmsKeyId: config.encryptionKey.keyArn } : {}),
      },
    };
  }

  /**
   * Renders `ResourceConfig`. The instance type is prefixed with `ml.`
   * unless it is an encoded JSON path, which is passed through unchanged.
   */
  renderResourceConfig(config) {
    const instanceType = config.instanceType.toString();
    return {
      ResourceConfig: {
        InstanceCount: config.instanceCount,
        InstanceType: sfn.JsonPath.isEncodedJsonPath(instanceType) ? instanceType : `ml.${instanceType}`,
        VolumeSizeInGB: config.volumeSize.toGibibytes(),
        ...(config.volumeEncryptionKey ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}),
      },
    };
  }

  /** Renders `StoppingCondition`; `MaxRuntimeInSeconds` is falsy when `maxRuntime` is absent. */
  renderStoppingCondition(config) {
    return {
      StoppingCondition: {
        MaxRuntimeInSeconds: config.maxRuntime && config.maxRuntime.toSeconds(),
      },
    };
  }

  /** Renders `HyperParameters` only when provided. */
  renderHyperparameters(params) {
    return params ? { HyperParameters: params } : {};
  }

  /**
   * Renders `VpcConfig` only when a VPC config is provided. Security group IDs
   * go through core `Lazy.list`, so `this.securityGroups` is read when the token
   * is resolved. Groups added later via addSecurityGroup are therefore included.
   */
  renderVpcConfig(config) {
    if (!config) {
      return {};
    }
    return {
      VpcConfig: {
        SecurityGroupIds: core_1.Lazy.list({
          produce: () => this.securityGroups.map((sg) => sg.securityGroupId),
        }),
        Subnets: this.subnets,
      },
    };
  }

  /**
   * Resolves the SageMaker execution role and computes the policy statements
   * the state machine needs.
   *
   * Side effects:
   * - creates "SagemakerRole" when `props.role` is not given;
   * - grants encrypt on the output key and `kms:CreateGrant` on the volume key;
   * - creates "TrainJobSecurityGroup" when a VPC is configured.
   *
   * @returns the statements for the state machine: create/describe/stop scoped
   *   to the training-job ARN, ListTags, PassRole restricted to SageMaker, and
   *   (RUN_JOB only) EventBridge rule permissions.
   */
  makePolicyStatements() {
    const ec2VpcActions = this.props.vpcConfig
      ? [
          "ec2:CreateNetworkInterface",
          "ec2:CreateNetworkInterfacePermission",
          "ec2:DeleteNetworkInterface",
          "ec2:DeleteNetworkInterfacePermission",
          "ec2:DescribeNetworkInterfaces",
          "ec2:DescribeVpcs",
          "ec2:DescribeDhcpOptions",
          "ec2:DescribeSubnets",
          "ec2:DescribeSecurityGroups",
        ]
      : [];

    this._grantPrincipal = this._role = this.props.role || new iam.Role(this, "SagemakerRole", {
      assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"),
      inlinePolicies: {
        CreateTrainingJob: new iam.PolicyDocument({
          statements: [
            new iam.PolicyStatement({
              actions: [
                "cloudwatch:PutMetricData",
                "logs:CreateLogStream",
                "logs:PutLogEvents",
                "logs:CreateLogGroup",
                "logs:DescribeLogStreams",
                "ecr:GetAuthorizationToken",
                ...ec2VpcActions,
              ],
              resources: ["*"],
            }),
          ],
        }),
      },
    });

    if (this.props.outputDataConfig.encryptionKey) {
      this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role);
    }
    if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) {
      this.props.resourceConfig.volumeEncryptionKey.grant(this._role, "kms:CreateGrant");
    }

    if (this.vpc && this.securityGroup === undefined) {
      this.securityGroup = new ec2.SecurityGroup(this, "TrainJobSecurityGroup", { vpc: this.vpc });
      this.connections.addSecurityGroup(this.securityGroup);
      this.securityGroups.push(this.securityGroup);
    }

    const stack = core_1.Stack.of(this);
    // A JSON-path job name is only known at run time, so the ARN is wildcarded.
    const jobResourceName = sfn.JsonPath.isEncodedJsonPath(this.props.trainingJobName)
      ? "*"
      : `${this.props.trainingJobName}*`;

    const policyStatements = [
      new iam.PolicyStatement({
        actions: ["sagemaker:CreateTrainingJob", "sagemaker:DescribeTrainingJob", "sagemaker:StopTrainingJob"],
        resources: [
          stack.formatArn({
            service: "sagemaker",
            resource: "training-job",
            resourceName: jobResourceName,
          }),
        ],
      }),
      new iam.PolicyStatement({
        actions: ["sagemaker:ListTags"],
        resources: ["*"],
      }),
      new iam.PolicyStatement({
        actions: ["iam:PassRole"],
        resources: [this._role.roleArn],
        conditions: {
          StringEquals: { "iam:PassedToService": "sagemaker.amazonaws.com" },
        },
      }),
    ];

    if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) {
      policyStatements.push(
        new iam.PolicyStatement({
          actions: ["events:PutTargets", "events:PutRule", "events:DescribeRule"],
          resources: [
            stack.formatArn({
              service: "events",
              resource: "rule",
              resourceName: "StepFunctionsGetEventsForSageMakerTrainingJobsRule",
            }),
          ],
        }),
      );
    }

    return policyStatements;
  }
}

exports.SageMakerCreateTrainingJob = SageMakerCreateTrainingJob;

// jsii runtime type information.
SageMakerCreateTrainingJob[JSII_RTTI_SYMBOL_1] = {
  fqn: "aws-cdk-lib.aws_stepfunctions_tasks.SageMakerCreateTrainingJob",
  version: "2.70.0",
};

SageMakerCreateTrainingJob.SUPPORTED_INTEGRATION_PATTERNS = [
  sfn.IntegrationPattern.REQUEST_RESPONSE,
  sfn.IntegrationPattern.RUN_JOB,
];