@autorest/go
Version:
AutoRest Go Generator
1,011 lines (1,010 loc) • 75 kB
JavaScript
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as go from '../../codemodel.go/src/index.js';
import { ensureNameCase } from '../../naming.go/src/naming.js';
import { capitalize, comment, uncapitalize } from '@azure-tools/codegen';
import { values } from '@azure-tools/linq';
import * as helpers from './helpers.js';
import { ImportManager } from './imports.js';
import { CodegenError } from './errors.js';
/**
 * Pairs an operation group (i.e. a client) name with the generated Go source for it.
 */
export class OperationGroupContent {
    name;
    content;
    /**
     * @param name the name of the operation group (the client name)
     * @param content the complete generated file content for that group
     */
    constructor(name, content) {
        Object.assign(this, { name, content });
    }
}
// Creates the content for all <operation>.go files
/**
 * Generates the content of one Go source file per client in the code model.
 *
 * Each file contains:
 * - the client type definition, including its client parameter fields;
 * - any constructors and client options type;
 * - accessors for sub-clients;
 * - the client's operation methods;
 * - the protocol request/response helpers those methods call.
 *
 * Next-page request helpers shared by several pageable methods are emitted
 * only once per client.
 *
 * @param codeModel the Go code model to generate from
 * @returns one OperationGroupContent per client; empty when the model has no clients
 * @throws CodegenError when a sub-client has no matching accessor on its parent client
 */
export async function generateOperations(codeModel) {
    // generate protocol operations
    const operations = [];
    if (codeModel.clients.length === 0) {
        return operations;
    }
    const azureARM = codeModel.type === 'azure-arm';
    // for client params, only optional and flag types are passed by pointer
    const isParamPointer = (param) => param.style === 'flag' || param.style === 'optional';
    for (const client of codeModel.clients) {
        // the list of packages to import
        const imports = new ImportManager();
        if (client.methods.length > 0) {
            // add standard imports for clients with methods.
            // clients that are purely hierarchical (i.e. having no APIs) won't need them.
            imports.add('net/http');
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy');
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
        }
        imports.add(azureARM ? 'github.com/Azure/azure-sdk-for-go/sdk/azcore/arm' : 'github.com/Azure/azure-sdk-for-go/sdk/azcore');
        // generate client type
        let clientText = helpers.formatDocComment(client.docs);
        clientText += '// Don\'t use this type directly, use ';
        if (client.instance?.kind === 'constructable' && client.instance.constructors.length === 1) {
            clientText += `${client.instance.constructors[0].name}() instead.\n`;
        }
        else if (client.parent) {
            // find the accessor method on the parent that returns this client
            let accessorMethod;
            for (const clientAccessor of client.parent.clientAccessors) {
                if (clientAccessor.subClient === client) {
                    accessorMethod = clientAccessor.name;
                    break;
                }
            }
            if (!accessorMethod) {
                throw new CodegenError('InternalError', `didn't find accessor method for client ${client.name} on parent client ${client.parent.name}`);
            }
            clientText += `[${client.parent.name}.${accessorMethod}] instead.\n`;
        }
        else {
            clientText += 'a constructor function instead.\n';
        }
        clientText += `type ${client.name} struct {\n`;
        clientText += `\tinternal *${azureARM ? 'arm' : 'azcore'}.Client\n`;
        // now emit any client params (non parameterized host params case).
        // grouped params become a single field per group.
        if (client.parameters.length > 0) {
            const addedGroups = new Set();
            for (const clientParam of values(client.parameters)) {
                if (go.isLiteralParameter(clientParam.style)) {
                    // literal values are emitted inline, so no field is needed
                    continue;
                }
                if (clientParam.group) {
                    if (!addedGroups.has(clientParam.group.groupName)) {
                        clientText += `\t${uncapitalize(clientParam.group.groupName)} ${!isParamPointer(clientParam) ? '' : '*'}${clientParam.group.groupName}\n`;
                        addedGroups.add(clientParam.group.groupName);
                    }
                    continue;
                }
                clientText += `\t${clientParam.name} `;
                if (!isParamPointer(clientParam)) {
                    clientText += `${go.getTypeDeclaration(clientParam.type)}\n`;
                }
                else {
                    clientText += `${helpers.formatParameterTypeName(clientParam)}\n`;
                }
            }
        }
        // end of client definition
        clientText += '}\n\n';
        clientText += generateConstructors(client, codeModel.type, imports);
        // generate client accessors and operations
        let opText = '';
        for (const clientAccessor of client.clientAccessors) {
            opText += `// ${clientAccessor.name} creates a new instance of [${clientAccessor.subClient.name}].\n`;
            opText += `func (client *${client.name}) ${clientAccessor.name}() *${clientAccessor.subClient.name} {\n`;
            opText += `\treturn &${clientAccessor.subClient.name}{\n`;
            opText += '\t\tinternal: client.internal,\n';
            // propagate all client params
            for (const param of client.parameters) {
                if (go.isLiteralParameter(param.style)) {
                    continue;
                }
                opText += `\t\t${param.name}: client.${param.name},\n`;
            }
            opText += '\t}\n}\n\n';
        }
        // multiple operations can share the same next page operation; a Set
        // dedupes them while preserving first-seen (insertion) order
        const nextPageMethods = new Set();
        for (const method of client.methods) {
            // protocol creation can add imports to the list so
            // it must be done before the imports are written out
            if (go.isLROMethod(method)) {
                // generate Begin method
                opText += generateLROBeginMethod(method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes);
            }
            opText += generateOperation(method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes);
            opText += createProtocolRequest(azureARM, method, imports);
            if (method.kind !== 'lroMethod') {
                // LRO responses are handled elsewhere, with the exception of pageable LROs
                opText += createProtocolResponse(method, imports);
            }
            if ((method.kind === 'lroPageableMethod' || method.kind === 'pageableMethod') && method.nextPageMethod) {
                nextPageMethods.add(method.nextPageMethod);
            }
        }
        for (const method of nextPageMethods) {
            opText += createProtocolRequest(azureARM, method, imports);
        }
        // stitch it all together
        let text = helpers.contentPreamble(codeModel);
        text += imports.text();
        text += clientText;
        text += opText;
        operations.push(new OperationGroupContent(client.name, text));
    }
    return operations;
}
/**
 * generates all modeled client constructors and client options types.
 * if there are no client constructors, the empty string is returned.
 *
 * @param client the client for which to generate constructors and the client options type
 * @param type the code model type (e.g. 'azure-arm'); passed through when sorting constructor parameters
 * @param imports the import manager currently in scope
 * @returns the client constructor code or the empty string
 * @throws CodegenError when a constructor's credential type or the client options kind isn't supported
 */
function generateConstructors(client, type, imports) {
    if (client.instance?.kind !== 'constructable') {
        return '';
    }
    const clientOptions = client.instance.options;
    let ctorText = '';
    if (clientOptions.kind === 'clientOptions') {
        // for non-ARM, the options type will always be a parameter group
        ctorText += `// ${clientOptions.name} contains the optional values for creating a [${client.name}].\n`;
        ctorText += `type ${clientOptions.name} struct {\n\tazcore.ClientOptions\n`;
        for (const param of clientOptions.parameters) {
            if (go.isAPIVersionParameter(param)) {
                // we use azcore.ClientOptions.APIVersion
                continue;
            }
            ctorText += helpers.formatDocCommentWithPrefix(ensureNameCase(param.name), param.docs);
            if (go.isClientSideDefault(param.style)) {
                if (!param.docs.description && !param.docs.summary) {
                    ctorText += '\n';
                }
                ctorText += `\t${comment(`The default value is ${helpers.formatLiteralValue(param.style.defaultValue, false)}`, '// ')}.\n`;
            }
            ctorText += `\t${ensureNameCase(param.name)} *${go.getTypeDeclaration(param.type)}\n`;
        }
        ctorText += '}\n\n';
    }
    for (const constructor of client.instance.constructors) {
        const ctorParams = [];
        const paramDocs = [];
        // ctor params can also be present in the supplemental endpoint parameters.
        // the endpoint param (when present) is always first.
        const consolidatedCtorParams = [];
        if (client.instance.endpoint) {
            consolidatedCtorParams.push(client.instance.endpoint.parameter);
            if (client.instance.endpoint.supplemental) {
                consolidatedCtorParams.push(...client.instance.endpoint.supplemental.parameters);
            }
        }
        for (const param of helpers.sortClientParameters(constructor.parameters, type)) {
            if (!consolidatedCtorParams.includes(param)) {
                consolidatedCtorParams.push(param);
            }
        }
        for (const ctorParam of consolidatedCtorParams) {
            if (!go.isRequiredParameter(ctorParam.style)) {
                // param is part of the options group
                continue;
            }
            imports.addImportForType(ctorParam.type);
            ctorParams.push(`${ctorParam.name} ${helpers.formatParameterTypeName(ctorParam)}`);
            if (ctorParam.docs.summary || ctorParam.docs.description) {
                paramDocs.push(helpers.formatCommentAsBulletItem(ctorParam.name, ctorParam.docs));
            }
        }
        // emits the options defaulting, optional cloud/audience validation (tokenAuth),
        // and the azcore.NewClient call including any api-version and pipeline options
        const emitProlog = function (optionsTypeName, tokenAuth, plOpts) {
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
            let bodyText = `\tif options == nil {\n\t\toptions = &${optionsTypeName}{}\n\t}\n`;
            let apiVersionConfig = '';
            // check if there's an api version parameter
            let apiVersionParam;
            for (const param of consolidatedCtorParams) {
                switch (param.kind) {
                    case 'headerScalarParam':
                    case 'pathScalarParam':
                    case 'queryScalarParam':
                    case 'uriParam':
                        if (param.isApiVersion) {
                            apiVersionParam = param;
                        }
                }
            }
            if (tokenAuth) {
                imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud');
                imports.add('fmt');
                imports.add('reflect');
                bodyText += '\tif reflect.ValueOf(options.Cloud).IsZero() {\n';
                bodyText += '\t\toptions.Cloud = cloud.AzurePublic\n\t}\n';
                bodyText += '\tc, ok := options.Cloud.Services[ServiceName]\n';
                bodyText += '\tif !ok {\n';
                bodyText += '\t\treturn nil, fmt.Errorf("provided Cloud field is missing configuration for %s", ServiceName)\n';
                bodyText += '\t} else if c.Audience == "" {\n';
                bodyText += '\t\treturn nil, fmt.Errorf("provided Cloud field is missing Audience for %s", ServiceName)\n\t}\n';
            }
            if (apiVersionParam) {
                let location;
                let name;
                switch (apiVersionParam.kind) {
                    case 'headerScalarParam':
                        location = 'Header';
                        name = apiVersionParam.headerName;
                        break;
                    case 'pathScalarParam':
                    case 'uriParam':
                        location = 'Path';
                        // name isn't used for the path case
                        break;
                    case 'queryScalarParam':
                        location = 'QueryParam';
                        name = apiVersionParam.queryParameter;
                        break;
                }
                if (name) {
                    name = `\n\t\t\tName: "${name}",`;
                }
                else {
                    name = '';
                }
                apiVersionConfig = `\n\t\tAPIVersion: runtime.APIVersionOptions{${name}\n\t\t\tLocation: runtime.APIVersionLocation${location},\n\t\t},`;
                if (!plOpts) {
                    apiVersionConfig += '\n';
                }
            }
            bodyText += `\tcl, err := azcore.NewClient(moduleName, moduleVersion, runtime.PipelineOptions{${apiVersionConfig}${plOpts ?? ''}}, &options.ClientOptions)\n`;
            return bodyText;
        };
        // check if there's a credential parameter
        const credentialParam = constructor.parameters.find((param) => param.kind === 'credentialParam');
        let prolog;
        if (credentialParam) {
            switch (credentialParam.type.kind) {
                case 'tokenCredential':
                    imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
                    paramDocs.push(helpers.formatCommentAsBulletItem('credential', { summary: 'used to authorize requests. Usually a credential from azidentity.' }));
                    switch (clientOptions.kind) {
                        case 'clientOptions': {
                            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy');
                            const tokenPolicyOpts = '&policy.BearerTokenOptions{\n\t\t\tInsecureAllowCredentialWithHTTP: options.InsecureAllowCredentialWithHTTP,\n\t\t}';
                            // we assume a single scope. this is enforced when adapting the data from tcgc
                            const tokenPolicy = `\n\t\tPerCall: []policy.Policy{\n\t\truntime.NewBearerTokenPolicy(credential, []string{c.Audience + "${helpers.splitScope(credentialParam.type.scopes[0]).scope}"}, ${tokenPolicyOpts}),\n\t\t},\n`;
                            prolog = emitProlog(go.getTypeDeclaration(clientOptions), true, tokenPolicy);
                            break;
                        }
                        case 'armClientOptions':
                            // this is the ARM case
                            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm');
                            prolog = '\tcl, err := arm.NewClient(moduleName, moduleVersion, credential, options)\n';
                            break;
                        default:
                            // without a prolog the generated constructor would be malformed
                            throw new CodegenError('InternalError', `unhandled client options kind ${clientOptions.kind} for constructor ${constructor.name}`);
                    }
                    break;
                default:
                    // without a prolog the generated constructor would be malformed
                    throw new CodegenError('InternalError', `unhandled credential type ${credentialParam.type.kind} for constructor ${constructor.name}`);
            }
        }
        else {
            prolog = emitProlog(go.getTypeDeclaration(clientOptions), false);
        }
        // add client options last
        ctorParams.push(`options ${helpers.formatParameterTypeName(clientOptions)}`);
        paramDocs.push(helpers.formatCommentAsBulletItem('options', { summary: 'Contains optional client configuration. Pass nil to accept the default values.' }));
        ctorText += `// ${constructor.name} creates a new instance of ${client.name} with the specified values.\n`;
        for (const doc of paramDocs) {
            ctorText += doc;
        }
        ctorText += `func ${constructor.name}(${ctorParams.join(', ')}) (*${client.name}, error) {\n`;
        ctorText += prolog;
        ctorText += '\tif err != nil {\n';
        ctorText += '\t\treturn nil, err\n';
        ctorText += '\t}\n';
        // handle any client-side defaults
        if (clientOptions.kind === 'clientOptions') {
            for (const param of clientOptions.parameters) {
                if (go.isClientSideDefault(param.style)) {
                    const name = go.isAPIVersionParameter(param) ? 'APIVersion' : ensureNameCase(param.name);
                    ctorText += `\t${param.name} := ${helpers.formatLiteralValue(param.style.defaultValue, false)}\n`;
                    ctorText += `\tif options.${name} != ${helpers.zeroValue(param)} {\n\t\t${param.name} = ${helpers.star(param.byValue)}options.${name}\n\t}\n`;
                }
            }
        }
        // construct the supplemental path and join it to the endpoint
        if (client.instance.endpoint?.supplemental) {
            imports.add('strings');
            ctorText += `\thost := "${client.instance.endpoint.supplemental.path}"\n`;
            for (const param of client.instance.endpoint.supplemental.parameters) {
                ctorText += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(param.name, param.type, imports)})\n`;
            }
            // use the endpoint definition itself rather than assuming it's the first
            // parameter of the first constructor; this holds for every constructor
            // and any parameter ordering, and matches the name in the signature.
            const endpointParam = client.instance.endpoint.parameter;
            ctorText += `\t${endpointParam.name} = runtime.JoinPaths(${endpointParam.name}, host)\n`;
        }
        // construct client literal, ensuring clientVar doesn't collide with any params
        const clientVar = consolidatedCtorParams.some((param) => param.name === 'client')
            ? ensureNameCase(client.name, true)
            : 'client';
        ctorText += `\t${clientVar} := &${client.name}{\n`;
        // NOTE: we don't enumerate consolidatedCtorParams here
        // as any supplemental endpoint params are ephemeral and
        // consumed during client construction.
        for (const parameter of values(client.parameters)) {
            if (go.isLiteralParameter(parameter.style)) {
                continue;
            }
            // each client field will have a matching parameter with the same name
            ctorText += `\t\t${parameter.name}: ${parameter.name},\n`;
        }
        ctorText += '\t\tinternal: cl,\n';
        ctorText += '\t}\n';
        ctorText += `\treturn ${clientVar}, nil\n`;
        ctorText += '}\n\n';
    }
    return ctorText;
}
/**
 * Generates the Go code that reads a response header into a field of the response envelope.
 *
 * Header map responses collect every header whose name starts with the configured
 * prefix into a map[string]*string keyed by the remainder of the header name.
 * All other header responses read a single header; when it's non-empty, the value is
 * converted according to the header's type and assigned to the envelope field.
 * Conversions that can fail return `zeroResp` and the parse error.
 *
 * @param headerResp the header response definition
 * @param imports the import manager currently in scope
 * @param respObj the Go identifier of the response envelope being populated
 * @param zeroResp the Go expression returned alongside a parse error
 * @returns the generated Go statements
 * @throws CodegenError for header types or formats that have no conversion
 */
function formatHeaderResponseValue(headerResp, imports, respObj, zeroResp) {
    // dictionaries are handled slightly different so we do that first
    if (headerResp.kind === 'headerMapResponse') {
        imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
        imports.add('strings');
        const headerPrefix = headerResp.headerName;
        let text = '\tfor hh := range resp.Header {\n';
        text += `\t\tif len(hh) > len("${headerPrefix}") && strings.EqualFold(hh[:len("${headerPrefix}")], "${headerPrefix}") {\n`;
        text += `\t\t\tif ${respObj}.${headerResp.fieldName} == nil {\n`;
        text += `\t\t\t\t${respObj}.${headerResp.fieldName} = map[string]*string{}\n`;
        text += '\t\t\t}\n';
        text += `\t\t\t${respObj}.${headerResp.fieldName}[hh[len("${headerPrefix}"):]] = to.Ptr(resp.Header.Get(hh))\n`;
        text += '\t\t}\n';
        text += '\t}\n';
        return text;
    }
    let text = `\tif val := resp.Header.Get("${headerResp.headerName}"); val != "" {\n`;
    // name of the local holding the converted value (or an expression producing a pointer)
    let name = uncapitalize(headerResp.fieldName);
    // whether the converted value must have its address taken when assigned
    let byRef = '&';
    switch (headerResp.type.kind) {
        case 'constant':
            text += `\t\t${respObj}.${headerResp.fieldName} = (*${headerResp.type.name})(&val)\n`;
            text += '\t}\n';
            return text;
        case 'encodedBytes':
            // a base-64 encoded value in string format
            imports.add('encoding/base64');
            text += `\t\t${name}, err := base64.${helpers.formatBytesEncoding(headerResp.type.encoding)}Encoding.DecodeString(val)\n`;
            byRef = '';
            break;
        case 'literal':
            text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`;
            text += '\t}\n';
            return text;
        case 'scalar':
            imports.add('strconv');
            switch (headerResp.type.type) {
                case 'bool':
                    text += `\t\t${name}, err := strconv.ParseBool(val)\n`;
                    break;
                case 'float32':
                    text += `\t\t${name}32, err := strconv.ParseFloat(val, 32)\n`;
                    text += `\t\t${name} := float32(${name}32)\n`;
                    break;
                case 'float64':
                    text += `\t\t${name}, err := strconv.ParseFloat(val, 64)\n`;
                    break;
                case 'int32':
                    text += `\t\t${name}32, err := strconv.ParseInt(val, 10, 32)\n`;
                    text += `\t\t${name} := int32(${name}32)\n`;
                    break;
                case 'int64':
                    text += `\t\t${name}, err := strconv.ParseInt(val, 10, 64)\n`;
                    break;
                default:
                    throw new CodegenError('InternalError', `unhandled scalar type ${headerResp.type.type}`);
            }
            break;
        case 'string':
            text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`;
            text += '\t}\n';
            return text;
        case 'time':
            imports.add('time');
            switch (headerResp.type.format) {
                case 'dateTimeRFC1123':
                case 'dateTimeRFC3339':
                    text += `\t\t${name}, err := time.Parse(${headerResp.type.format === 'dateTimeRFC1123' ? helpers.datetimeRFC1123Format : helpers.datetimeRFC3339Format}, val)\n`;
                    break;
                case 'dateType':
                    text += `\t\t${name}, err := time.Parse("${helpers.dateFormat}", val)\n`;
                    break;
                case 'timeRFC3339':
                    text += `\t\t${name}, err := time.Parse("${helpers.timeRFC3339Format}", val)\n`;
                    break;
                case 'timeUnix':
                    imports.add('strconv');
                    imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
                    text += '\t\tsec, err := strconv.ParseInt(val, 10, 64)\n';
                    name = 'to.Ptr(time.Unix(sec, 0))';
                    byRef = '';
                    break;
                default:
                    // falling through would emit an err check without any parse, i.e. invalid Go
                    throw new CodegenError('InternalError', `unhandled time format ${headerResp.type.format}`);
            }
            break;
        default:
            // falling through would emit an err check without any parse, i.e. invalid Go
            throw new CodegenError('InternalError', `unhandled header response type ${headerResp.type.kind}`);
    }
    // NOTE: only cases that required parsing will fall through to here
    text += '\t\tif err != nil {\n';
    text += `\t\t\treturn ${zeroResp}, err\n`;
    text += '\t\t}\n';
    text += `\t\t${respObj}.${headerResp.fieldName} = ${byRef}${name}\n`;
    text += '\t}\n';
    return text;
}
/**
 * Returns the Go expression to use as the non-error return value when a method fails.
 *
 * For LRO methods, the public API ('api') returns a *Poller[T] and the internal
 * operation ('op') returns an *http.Response, so their zero value is `nil`.
 * Everything else returns an empty response envelope literal.
 *
 * @param method the method being generated
 * @param apiType which generated function the value is for (e.g. 'api' or 'op')
 * @returns the zero-value Go expression
 */
function getZeroReturnValue(method, apiType) {
    const envelopeZero = `${method.returns.name}{}`;
    const returnsPointer = apiType === 'api' || apiType === 'op';
    if (!go.isLROMethod(method) || !returnsPointer) {
        return envelopeZero;
    }
    return 'nil';
}
/**
 * Builds a Go boolean expression that nil-checks every prefix of a dotted field path.
 * e.g. ('Properties.NextLink') => 'page.Properties != nil && page.Properties.NextLink != nil'
 *
 * @param path the dotted field path (relative to prefix)
 * @param prefix the Go identifier the path is rooted at; defaults to 'page'
 * @returns the conjunction of nil checks
 */
function generateNilChecks(path, prefix = 'page') {
    const parts = path.split('.');
    return parts
        .map((_, idx) => `${[prefix, ...parts.slice(0, idx + 1)].join('.')} != nil`)
        .join(' && ');
}
/**
 * Emits the Go expression that constructs a runtime.NewPager for a pageable method.
 * The returned text has no leading `return`; it starts with `runtime.NewPager(` and
 * ends with `\t})\n`.
 *
 * @param method the pageable method for which to emit the pager
 * @param imports the import manager currently in scope ('context' is added here)
 * @param injectSpans when true, the paging handler is given the client's tracer
 * @param generateFakes when true, the fetcher stores the API name in the context
 * @returns the pager construction expression
 */
function emitPagerDefinition(method, imports, injectSpans, generateFakes) {
imports.add('context');
let text = `runtime.NewPager(runtime.PagingHandler[${method.returns.name}]{\n`;
text += `\t\tMore: func(page ${method.returns.name}) bool {\n`;
// there is no advancer for single-page pagers
if (method.nextLinkName) {
// nextLinkName may be a dotted path, so each segment is nil-checked
// before the final next link value is dereferenced
const nilChecks = generateNilChecks(method.nextLinkName);
text += `\t\t\treturn ${nilChecks} && len(*page.${method.nextLinkName}) > 0\n`;
text += '\t\t},\n';
}
else {
text += '\t\t\treturn false\n';
text += '\t\t},\n';
}
text += `\t\tFetcher: func(ctx context.Context, page *${method.returns.name}) (${method.returns.name}, error) {\n`;
const reqParams = helpers.getCreateRequestParameters(method);
if (generateFakes) {
// NOTE(review): emitted with two tabs while sibling Fetcher statements use three;
// presumably normalized by gofmt downstream — confirm
text += `\t\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, "${method.receiver.type.name}.${fixUpMethodName(method)}")\n`;
}
if (method.nextLinkName) {
let nextLinkVar;
if (method.kind === 'pageableMethod') {
// page is nil-checked (presumably nil before the first page is fetched);
// an empty next link is passed in that case
text += '\t\t\tnextLink := ""\n';
nextLinkVar = 'nextLink';
text += '\t\t\tif page != nil {\n';
text += `\t\t\t\tnextLink = *page.${method.nextLinkName}\n\t\t\t}\n`;
}
else {
// other kinds (presumably LRO pageables) dereference the next link without a page nil check
nextLinkVar = `*page.${method.nextLinkName}`;
}
text += `\t\t\tresp, err := runtime.FetcherForNextLink(ctx, client.internal.Pipeline(), ${nextLinkVar}, func(ctx context.Context) (*policy.Request, error) {\n`;
text += `\t\t\t\treturn client.${method.naming.requestMethod}(${reqParams})\n\t\t\t}, `;
// nextPageMethod might be absent in some cases, see https://github.com/Azure/autorest/issues/4393
if (method.nextPageMethod) {
const nextOpParams = helpers.getCreateRequestParametersSig(method.nextPageMethod).split(',');
// keep the parameter names from the name/type tuples and find nextLink param.
// a string-typed param whose name starts with 'next' is assumed to be the next
// link and is replaced with the encoded link handed to NextReq.
for (let i = 0; i < nextOpParams.length; ++i) {
const paramName = nextOpParams[i].trim().split(' ')[0];
const paramType = nextOpParams[i].trim().split(' ')[1];
if (paramName.startsWith('next') && paramType === 'string') {
nextOpParams[i] = 'encodedNextLink';
}
else {
nextOpParams[i] = paramName;
}
}
// add a definition for the nextReq func that uses the nextLinkOperation
text += '&runtime.FetcherForNextLinkOptions{\n\t\t\t\tNextReq: func(ctx context.Context, encodedNextLink string) (*policy.Request, error) {\n';
text += `\t\t\t\t\treturn client.${method.nextPageMethod.name}(${nextOpParams.join(', ')})\n\t\t\t\t},\n\t\t\t})\n`;
}
else {
// no next page method: pass nil options to FetcherForNextLink
text += 'nil)\n';
}
text += `\t\t\tif err != nil {\n\t\t\t\treturn ${method.returns.name}{}, err\n\t\t\t}\n`;
text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`;
text += '\t\t\t},\n';
}
else {
// this is the singular page case, no fetcher helper required.
// only http.StatusOK is treated as success here.
text += `\t\t\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`;
text += '\t\t\tif err != nil {\n';
text += `\t\t\t\treturn ${method.returns.name}{}, err\n`;
text += '\t\t\t}\n';
text += '\t\t\tresp, err := client.internal.Pipeline().Do(req)\n';
text += '\t\t\tif err != nil {\n';
text += `\t\t\t\treturn ${method.returns.name}{}, err\n`;
text += '\t\t\t}\n';
text += '\t\t\tif !runtime.HasStatusCode(resp, http.StatusOK) {\n';
text += `\t\t\t\treturn ${method.returns.name}{}, runtime.NewResponseError(resp)\n`;
text += '\t\t\t}\n';
text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`;
text += '\t\t},\n';
}
if (injectSpans) {
text += '\t\tTracer: client.internal.Tracer(),\n';
}
text += '\t})\n';
return text;
}
/**
 * Builds the doc-comment trailer listing the API versions an operation was generated from.
 *
 * @param apiVersions the API versions for the operation
 * @returns the doc comment lines, or the empty string when there are no versions
 */
function genApiVersionDoc(apiVersions) {
    const hasNoVersions = apiVersions.length === 0;
    return hasNoVersions ? '' : `//\n// Generated from API version ${apiVersions.join(', ')}\n`;
}
/**
 * Returns the doc-comment line stating that the operation returns an *azcore.ResponseError
 * on failure, or the empty string when that statement doesn't apply.
 *
 * @param method the method being documented
 * @returns the doc comment line or the empty string
 */
function genRespErrorDoc(method) {
    // head-as-boolean methods don't return an error for 4xx status codes,
    // and pager constructors don't return an error at all.
    const isHeadAsBoolean = method.returns.result?.kind === 'headAsBooleanResult';
    if (isHeadAsBoolean || go.isPageableMethod(method)) {
        return '';
    }
    return '// If the operation fails it returns an *azcore.ResponseError type.\n';
}
/**
 * returns the Go method receiver clause for a client, e.g. `(client *FooClient)`.
 * the receiver is a pointer unless the receiver is declared by value.
 *
 * @param receiver the receiver for which to emit the definition
 * @returns the receiver definition
 */
function getClientReceiverDefinition(receiver) {
    const pointerMarker = receiver.byValue ? '' : '*';
    return `(${receiver.name} ${pointerMarker}${receiver.type.name})`;
}
/**
 * Emits the Go method for a client operation, including its doc comment.
 *
 * - Pageable methods: the method directly returns the pager from emitPagerDefinition.
 * - LRO methods: the emitted method uses method.naming.internalMethod as its name,
 *   and per-parameter bullet docs are skipped.
 * - All others: the method creates the request, sends it through the client's
 *   pipeline, checks the status codes and builds the response.
 *
 * @param method the method to emit
 * @param imports the import manager currently in scope
 * @param injectSpans when true, a tracing span wraps the operation
 * @param generateFakes when true, the API name is stored in the context
 * @returns the generated Go method text
 */
function generateOperation(method, imports, injectSpans, generateFakes) {
const params = getAPIParametersSig(method, imports);
const returns = generateReturnsInfo(method, 'op');
let methodName = method.name;
if (method.kind === 'pageableMethod') {
methodName = fixUpMethodName(method);
}
let text = '';
const respErrDoc = genRespErrorDoc(method);
const apiVerDoc = genApiVersionDoc(method.apiVersions);
if (method.docs.summary || method.docs.description) {
text += helpers.formatDocCommentWithPrefix(methodName, method.docs);
}
else if (respErrDoc.length > 0 || apiVerDoc.length > 0) {
// if the method has no doc comment but we're adding other
// doc comments, add an empty method name comment. this preserves
// existing behavior and makes the docs look better overall.
text += `// ${methodName} -\n`;
}
text += respErrDoc;
text += apiVerDoc;
if (go.isLROMethod(method)) {
// the doc comment above uses the public name; the emitted function uses the internal name
methodName = method.naming.internalMethod;
}
else {
for (const param of values(helpers.getMethodParameters(method))) {
text += helpers.formatCommentAsBulletItem(param.name, param.docs);
}
}
text += `func ${getClientReceiverDefinition(method.receiver)} ${methodName}(${params}) (${returns.join(', ')}) {\n`;
const reqParams = helpers.getCreateRequestParameters(method);
if (method.kind === 'pageableMethod') {
text += '\treturn ';
text += emitPagerDefinition(method, imports, injectSpans, generateFakes);
text += '}\n\n';
return text;
}
// err is declared up front so the deferred endSpan(err) observes its final value
text += '\tvar err error\n';
let operationName = `"${method.receiver.type.name}.${fixUpMethodName(method)}"`;
if (generateFakes && injectSpans) {
// both consumers need the name, so emit it once as a const
text += `\tconst operationName = ${operationName}\n`;
operationName = 'operationName';
}
if (generateFakes) {
text += `\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, ${operationName})\n`;
}
if (injectSpans) {
text += `\tctx, endSpan := runtime.StartSpan(ctx, ${operationName}, client.internal.Tracer(), nil)\n`;
text += '\tdefer func() { endSpan(err) }()\n';
}
const zeroResp = getZeroReturnValue(method, 'op');
text += `\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`;
text += '\tif err != nil {\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
text += '\thttpResp, err := client.internal.Pipeline().Do(req)\n';
text += '\tif err != nil {\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
text += `\tif !runtime.HasStatusCode(httpResp, ${helpers.formatStatusCodes(method.httpStatusCodes)}) {\n`;
// plain assignment (not :=) so the outer err seen by endSpan is set
text += '\t\terr = runtime.NewResponseError(httpResp)\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
// HAB with headers response is handled in protocol responder
if (method.returns.result?.kind === 'headAsBooleanResult' && method.returns.headers.length === 0) {
// any 2xx status maps to true
text += `\treturn ${method.returns.name}{${method.returns.result.fieldName}: httpResp.StatusCode >= 200 && httpResp.StatusCode < 300}, nil\n`;
}
else {
if (go.isLROMethod(method)) {
// the internal LRO operation hands the raw response back to its caller
text += '\treturn httpResp, nil\n';
}
else if (needsResponseHandler(method)) {
// also cheating here as at present the only param to the responder is an http.Response
text += `\tresp, err := client.${method.naming.responseMethod}(httpResp)\n`;
text += '\treturn resp, err\n';
}
else if (method.returns.result?.kind === 'binaryResult') {
// binary results expose the unread response body
text += `\treturn ${method.returns.name}{${method.returns.result.fieldName}: httpResp.Body}, nil\n`;
}
else {
text += `\treturn ${method.returns.name}{}, nil\n`;
}
}
text += '}\n\n';
return text;
}
function createProtocolRequest(azureARM, method, imports) {
let name = method.name;
if (method.kind !== 'nextPageMethod') {
name = method.naming.requestMethod;
}
for (const param of values(method.parameters)) {
if (param.location !== 'method' || !go.isRequiredParameter(param.style)) {
continue;
}
imports.addImportForType(param.type);
}
const returns = ['*policy.Request', 'error'];
let text = `${comment(name, '// ')} creates the ${method.name} request.\n`;
text += `func ${getClientReceiverDefinition(method.receiver)} ${name}(${helpers.getCreateRequestParametersSig(method)}) (${returns.join(', ')}) {\n`;
const hostParams = new Array();
for (const parameter of method.receiver.type.parameters) {
if (parameter.kind === 'uriParam') {
hostParams.push(parameter);
}
}
let hostParam;
if (azureARM) {
hostParam = 'client.internal.Endpoint()';
}
else if (method.receiver.type.instance?.kind === 'templatedHost') {
imports.add('strings');
// we have a templated host
text += `\thost := "${method.receiver.type.instance.path}"\n`;
// get all the host params on the client
for (const hostParam of hostParams) {
text += `\thost = strings.ReplaceAll(host, "{${hostParam.uriPathSegment}}", ${helpers.formatValue(`client.${hostParam.name}`, hostParam.type, imports)})\n`;
}
// check for any method local host params
for (const param of values(method.parameters)) {
if (param.location === 'method' && param.kind === 'uriParam') {
text += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(helpers.getParamName(param), param.type, imports)})\n`;
}
}
hostParam = 'host';
}
else if (hostParams.length === 1) {
// simple parameterized host case
hostParam = 'client.' + hostParams[0].name;
}
else {
throw new CodegenError('InternalError', `no host or endpoint defined for method ${method.receiver.type.name}.${method.name}`);
}
const methodParamGroups = helpers.getMethodParamGroups(method);
const hasPathParams = methodParamGroups.pathParams.length > 0;
// storage needs the client.u to be the source-of-truth for the full path.
// however, swagger requires that all operations specify a path, which is at odds with storage.
// to work around this, storage specifies x-ms-path paths with path params but doesn't
// actually reference the path params (i.e. no params with which to replace the tokens).
// so, if a path contains tokens but there are no path params, skip emitting the path.
const pathStr = method.httpPath;
const pathContainsParms = pathStr.includes('{');
if (hasPathParams || (!pathContainsParms && pathStr.length > 1)) {
// there are path params, or the path doesn't contain tokens and is not "/" so emit it
text += `\turlPath := "${method.httpPath}"\n`;
hostParam = `runtime.JoinPaths(${hostParam}, urlPath)`;
}
// helper to build nil checks for param groups
const emitParamGroupCheck = function (param) {
if (!param.group) {
throw new CodegenError('InternalError', `emitParamGroupCheck called for ungrouped parameter ${param.name}`);
}
let client = '';
if (param.location === 'client') {
client = 'client.';
}
const paramGroupName = uncapitalize(param.group.name);
let optionalParamGroupCheck = `${client}${paramGroupName} != nil && `;
if (param.group.required) {
optionalParamGroupCheck = '';
}
return `\tif ${optionalParamGroupCheck}${client}${paramGroupName}.${capitalize(param.name)} != nil {\n`;
};
if (hasPathParams) {
// swagger defines path params, emit path and replace tokens
imports.add('strings');
// replace path parameters
for (const pp of methodParamGroups.pathParams) {
let paramValue;
let optionalPathSep = false;
if (pp.style !== 'optional') {
// emit check to ensure path param isn't an empty string
if (pp.kind === 'pathScalarParam') {
const choiceIsString = function (type) {
return type.kind === 'constant' && type.type === 'string';
};
// we only need to do this for params that have an underlying type of string
if ((pp.type.kind === 'string' || choiceIsString(pp.type)) && !pp.omitEmptyStringCheck) {
const paramName = helpers.getParamName(pp);
imports.add('errors');
text += `\tif ${paramName} == "" {\n`;
text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`;
text += '\t}\n';
}
}
paramValue = helpers.formatParamValue(pp, imports);
// for collection-based path params, we emit the empty check
// after calling helpers.formatParamValue as that will have the
// var name that contains the slice.
if (pp.kind === 'pathCollectionParam') {
const paramName = helpers.getParamName(pp);
const joinedParamName = `${paramName}Param`;
text += `\t${joinedParamName} := ${paramValue}\n`;
imports.add('errors');
text += `\tif len(${joinedParamName}) == 0 {\n`;
text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`;
text += '\t}\n';
paramValue = joinedParamName;
}
}
else {
// param isn't required, so emit a local var with
// the correct default value, then populate it with
// the optional value when set.
paramValue = `optional${capitalize(pp.name)}`;
text += `\t${paramValue} := ""\n`;
text += emitParamGroupCheck(pp);
text += `\t${paramValue} = ${helpers.formatParamValue(pp, imports)}\n\t}\n`;
// there are two cases for optional path params.
// - /foo/bar/{optional}
// - /foo/bar{/optional}
// for the second case, we need to include a forward slash
if (method.httpPath[method.httpPath.indexOf(`{${pp.pathSegment}}`) - 1] !== '/') {
optionalPathSep = true;
}
}
const emitPathEscape = function () {
if (pp.isEncoded) {
imports.add('net/url');
return `url.PathEscape(${paramValue})`;
}
return paramValue;
};
if (optionalPathSep) {
text += `\tif len(${paramValue}) > 0 {\n`;
text += `\t\t${paramValue} = "/"+${emitPathEscape()}\n`;
text += '\t}\n';
}
else {
paramValue = emitPathEscape();
}
text += `\turlPath = strings.ReplaceAll(urlPath, "{${pp.pathSegment}}", ${paramValue})\n`;
}
}
text += `\treq, err := runtime.NewRequest(ctx, http.Method${capitalize(method.httpMethod)}, ${hostParam})\n`;
text += '\tif err != nil {\n';
text += '\t\treturn nil, err\n';
text += '\t}\n';
// add query parameters
const encodedParams = methodParamGroups.encodedQueryParams;
const unencodedParams = methodParamGroups.unencodedQueryParams;
// Wraps the Go statement `setter` in whatever guard the query parameter qp
// needs and returns the resulting Go source.
//  - method-level client-side default: delegated to emitClientSideDefault with a
//    reqQP.Set(...) formatter (setter itself is not used on this path);
//  - required, literal, or client-level client-side default: emitted unguarded;
//  - optional client field outside a param group: guarded by client.<name> != nil;
//  - anything else: guarded by the opening produced by emitParamGroupCheck(qp),
//    whose block is closed here.
const emitQueryParam = function (qp, setter) {
    const onClient = qp.location === 'client';
    if (qp.location === 'method' && go.isClientSideDefault(qp.style)) {
        const formatSet = (name, val) => `\treqQP.Set(${name}, ${val})`;
        return emitClientSideDefault(qp, qp.style, formatSet, imports);
    }
    const alwaysSet = go.isRequiredParameter(qp.style)
        || go.isLiteralParameter(qp.style)
        || (onClient && go.isClientSideDefault(qp.style));
    if (alwaysSet) {
        return `\t${setter}\n`;
    }
    // guarded body shared by both optional forms below
    const guardedBody = `\t\t${setter}\n\t}\n`;
    if (onClient && !qp.group) {
        // global optional param
        return `\tif client.${qp.name} != nil {\n${guardedBody}`;
    }
    return emitParamGroupCheck(qp) + guardedBody;
};
// emit encoded params first
if (encodedParams.length > 0) {
text += '\treqQP := req.Raw().URL.Query()\n';
for (const qp of values(encodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) {
let setter;
if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') {
setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`;
// emit a type conversion for the qv based on the array's element type
let queryVal;
const arrayQP = qp.type;
switch (arrayQP.elementType.kind) {
case 'constant':
switch (arrayQP.elementType.type) {
case 'string':
queryVal = 'string(qv)';
break;
default:
imports.add('fmt');
queryVal = 'fmt.Sprintf("%d", qv)';
}
break;
case 'string':
queryVal = 'qv';
break;
default:
imports.add('fmt');
queryVal = 'fmt.Sprintf("%v", qv)';
}
setter += `\t\treqQP.Add("${qp.queryParameter}", ${queryVal})\n`;
setter += '\t}';
}
else {
// cannot initialize setter to this value as helpers.formatParamValue() can change imports
setter = `reqQP.Set("${qp.queryParameter}", ${helpers.formatParamValue(qp, imports)})`;
}
text += emitQueryParam(qp, setter);
}
text += '\treq.Raw().URL.RawQuery = reqQP.Encode()\n';
}
// tack on any unencoded params to the end
if (unencodedParams.length > 0) {
if (encodedParams.length > 0) {
text += '\tunencodedParams := []string{req.Raw().URL.RawQuery}\n';
}
else {
text += '\tunencodedParams := []string{}\n';
}
for (const qp of values(unencodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) {
let setter;
if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') {
setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`;
setter += `\t\tunencodedParams = append(unencodedParams, "${qp.queryParameter}="+qv)\n`;
setter += '\t}';
}
else {
setter = `unencodedParams = append(unencodedParams, "${qp.queryParameter}="+${helpers.formatParamValue(qp, imports)})`;
}
text += emitQueryParam(qp, setter);
}
imports.add('strings');
text += '\treq.Raw().URL.RawQuery = strings.Join(unencodedParams, "&")\n';
}
if (method.kind !== 'nextPageMethod' && method.returns.result?.kind === 'binaryResult') {
// skip auto-body downloading for binary stream responses
text += '\truntime.SkipBodyDownload(req)\n';
}
// add specific request headers
// Returns the Go statement(s) that set a request header for headerParam.
// Every emitted line starts with `prefix`, the caller's current indentation.
//  - headerMapParam: ranges over the map and sets one header per non-nil entry,
//    using headerParam.headerName as the key prefix;
//  - method-level client-side default: delegated to emitClientSideDefault with a
//    formatter that assigns req.Raw().Header[name];
//  - anything else: a single assignment of helpers.formatParamValue(...).
const emitHeaderSet = function (headerParam, prefix) {
    if (headerParam.kind === 'headerMapParam') {
        let headerText = `${prefix}for k, v := range ${helpers.getParamName(headerParam)} {\n`;
        headerText += `${prefix}\tif v != nil {\n`;
        headerText += `${prefix}\t\treq.Raw().Header["${headerParam.headerName}"+k] = []string{*v}\n`;
        // close the nil check at its own nesting level, then the range loop
        headerText += `${prefix}\t}\n`;
        headerText += `${prefix}}\n`;
        return headerText;
    }
    if (headerParam.location === 'method' && go.isClientSideDefault(headerParam.style)) {
        return emitClientSideDefault(headerParam, headerParam.style, (name, val) => {
            return `${prefix}req.Raw().Header[${name}] = []string{${val}}`;
        }, imports);
    }
    return `${prefix}req.Raw().Header["${headerParam.headerName}"] = []string{${helpers.formatParamValue(headerParam, imports)}}\n`;
};
let contentType;
for (const param of methodParamGroups.headerParams.sort((a, b) => { return helpers.sortAscending(a.headerName, b.headerName); })) {
if (param.headerName.match(/^content-type$/)) {
// canonicalize content-type as req.SetBody checks for it via its canonicalized name :(
param.headerName = 'Content-Type';
}
if (param.headerName === 'Content-Type' && param.style === 'literal') {
// the content-type header will be set as part of emitSetBodyWithErrCheck
// to handle cases where the body param is optional. we don't want to set
// the content-type if the body is nil.
// we do it like this as tsp specifies content-type while swagger does not.
contentType = helpers.formatParamValue(param, imports);
}
else if (go.isRequiredParameter(param.style) || go.isLiteralParameter(param.style) || go.isClientSideDefault(param.style)) {
text += emitHeaderSet(param, '\t');
}
else if (param.location === 'client' && !param.group) {
// global optional param
text += `\tif client.${param.name} != nil {\n`;
text += emitHeaderSet(param, '\t');
text += '\t}\n';
}
else {
text += emitParamGroupCheck(param);
text += emitHeaderSet(param, '\t\t');
text += '\t}\n';
}
}
// note that these are mutually exclusive
const bodyParam = methodParamGroups.bodyParam;
const formBodyParams = methodParamGroups.formBodyParams;
const multipartBodyParams = methodParamGroups.multipartBodyParams;
const partialBodyParams = methodParamGroups.partialBodyParams;
// Returns Go source that runs the body-setting call `setBodyParam` and returns
// its error, if any. When contentType is truthy, the Content-Type header is
// assigned first, so it is only set on the path that actually sends a body.
// The result always ends with a newline. The previous template appended a
// stray ';', which emitted an empty Go statement after the error check.
const emitSetBodyWithErrCheck = function (setBodyParam, contentType) {
    let content = `if err := ${setBodyParam}; err != nil {\n\treturn nil, err\n}\n`;
    if (contentType) {
        content = `req.Raw().Header["Content-Type"] = []string{${contentType}}\n` + content;
    }
    return content;
};
if (bodyParam) {
if (bodyParam.bodyFormat === 'JSON' || bodyParam.bodyFormat === 'XML') {
// default to the body param name
let body = helpers.getParamName(bodyParam);
if (bodyParam.type.kind === 'literal') {
// if the value is constant, embed it directly
body = helpers.formatLiteralValue(bodyParam.type, true);
}
else if (bodyParam.bodyFormat === 'XML' && bodyParam.type.kind === 'slice') {
// for XML payloads, create a wrapper type if the payload is an array
imports.add('encoding/xml');
text += '\ttype wrapper struct {\n';
let tagName = go.getTypeDeclaration(bodyParam.type);
if (bodyParam.xml?.name) {
tagName = bodyParam.xml.name;
}
text += `\t\tXMLName xml.Name \`xml:"${tagName}"\`\n`;
const fieldName = capitalize(bodyParam.name);
let tag = go.getTypeDeclaration(bodyParam.type.elementType);
if (bodyParam.type.elementType.kind === 'model' && bodyParam.type.elementType.xml?.name) {
tag = bodyParam.type.elementType.xml.name;
}
text += `\t\t${fieldName} *${go.getTypeDeclaration(bodyParam.type)} \`xml:"${tag}"\`\n`;
text += '\t}\n';
let addr = '&';
if (!go.isRequiredParameter(bodyParam.style) && !bodyParam.byValue) {
addr = '';
}
body = `wrapper{${fieldName}: ${addr}${body}}`;
}
else if (bodyParam.type.kind === 'time' && bodyParam.type.format !== 'dateTimeRFC3339') {
// wrap the body in the internal time type
// no need for dateTimeRFC3339 as the JSON marshaler defaults to that.
body = `${bodyParam.type.format}(${body})`;
}
else if (isArrayOfDateTimeForMarshalling(bodyParam.type)) {
const timeInfo = isArrayOfDateTimeForMarshalling(bodyParam.type);
let elementPtr = '*';