@autorest/go
Version:
AutoRest Go Generator
1,026 lines • 64.3 kB
JavaScript
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as go from '../../codemodel.go/src/index.js';
import { capitalize, comment, uncapitalize } from '@azure-tools/codegen';
import { values } from '@azure-tools/linq';
import * as helpers from './helpers.js';
import { ImportManager } from './imports.js';
import { CodegenError } from './errors.js';
// represents the generated content for an operation group
export class OperationGroupContent {
    // the name of the operation group (generateOperations uses the client name)
    name;
    // the generated file text for the operation group
    content;
    constructor(name, content) {
        Object.assign(this, { name, content });
    }
}
// Creates the content for all <operation>.go files
// codeModel - the Go code model whose clients are generated
// returns one OperationGroupContent per client (named after the client); empty when there are no clients
// throws CodegenError when a sub-client's accessor method can't be found, or when an ARM client has optional client params
// NOTE: for ARM code models this mutates each client by appending a synthesized constructor to client.constructors.
export async function generateOperations(codeModel) {
    // generate protocol operations
    const operations = new Array();
    if (codeModel.clients.length === 0) {
        return operations;
    }
    const azureARM = codeModel.type === 'azure-arm';
    for (const client of codeModel.clients) {
        // the list of packages to import
        const imports = new ImportManager();
        if (client.methods.length > 0) {
            // add standard imports for clients with methods.
            // clients that are purely hierarchical (i.e. having no APIs) won't need them.
            imports.add('net/http');
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy');
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
        }
        let clientPkg = 'azcore';
        if (azureARM) {
            clientPkg = 'arm';
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm');
            // the synthesized constructor is emitted by generateConstructors below
            client.constructors.push(createARMClientConstructor(client, imports));
        }
        else {
            imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
        }
        // generate client type
        let clientText = helpers.formatDocComment(client.docs);
        clientText += '// Don\'t use this type directly, use ';
        if (client.constructors.length === 1) {
            clientText += `${client.constructors[0].name}() instead.\n`;
        }
        else if (client.parent) {
            // find the accessor method
            let accessorMethod;
            for (const clientAccessor of client.parent.clientAccessors) {
                if (clientAccessor.subClient === client) {
                    accessorMethod = clientAccessor.name;
                    break;
                }
            }
            if (!accessorMethod) {
                throw new CodegenError('InternalError', `didn't find accessor method for client ${client.name} on parent client ${client.parent.name}`);
            }
            clientText += `[${client.parent.name}.${accessorMethod}] instead.\n`;
        }
        else {
            clientText += 'a constructor function instead.\n';
        }
        clientText += `type ${client.name} struct {\n`;
        clientText += `\tinternal *${clientPkg}.Client\n`;
        // check for any optional host params
        const optionalParams = new Array();
        const isParamPointer = function (param) {
            // for client params, only optional and flag types are passed by pointer
            return param.style === 'flag' || param.style === 'optional';
        };
        // now emit any client params (non parameterized host params case).
        // literal params get no field, and grouped params share a single field per group.
        if (client.parameters.length > 0) {
            const addedGroups = new Set();
            for (const clientParam of values(client.parameters)) {
                if (go.isLiteralParameter(clientParam)) {
                    continue;
                }
                if (clientParam.group) {
                    if (!addedGroups.has(clientParam.group.groupName)) {
                        clientText += `\t${uncapitalize(clientParam.group.groupName)} ${!isParamPointer(clientParam) ? '' : '*'}${clientParam.group.groupName}\n`;
                        addedGroups.add(clientParam.group.groupName);
                    }
                    continue;
                }
                clientText += `\t${clientParam.name} `;
                if (!isParamPointer(clientParam)) {
                    clientText += `${go.getTypeDeclaration(clientParam.type)}\n`;
                }
                else {
                    clientText += `${helpers.formatParameterTypeName(clientParam)}\n`;
                }
                if (!go.isRequiredParameter(clientParam)) {
                    optionalParams.push(clientParam);
                }
            }
        }
        // end of client definition
        clientText += '}\n\n';
        if (azureARM && optionalParams.length > 0) {
            throw new CodegenError('UnsupportedTsp', 'optional client parameters for ARM is not supported');
        }
        // generate client constructors
        clientText += generateConstructors(azureARM, client, imports);
        // generate client accessors and operations
        let opText = '';
        for (const clientAccessor of client.clientAccessors) {
            opText += `// ${clientAccessor.name} creates a new instance of [${clientAccessor.subClient.name}].\n`;
            opText += `func (client *${client.name}) ${clientAccessor.name}() *${clientAccessor.subClient.name} {\n`;
            opText += `\treturn &${clientAccessor.subClient.name}{\n`;
            opText += '\t\tinternal: client.internal,\n';
            // propagate the client params using the same field rules as the client struct above:
            // literal params have no field, and grouped params are copied once via their group field.
            // NOTE(review): assumes the sub-client declares the same client-param fields as its parent; confirm
            const copiedGroups = new Set();
            for (const param of client.parameters) {
                if (go.isLiteralParameter(param)) {
                    continue;
                }
                if (param.group) {
                    if (!copiedGroups.has(param.group.groupName)) {
                        const groupField = uncapitalize(param.group.groupName);
                        opText += `\t\t${groupField}: client.${groupField},\n`;
                        copiedGroups.add(param.group.groupName);
                    }
                    continue;
                }
                opText += `\t\t${param.name}: client.${param.name},\n`;
            }
            opText += '\t}\n}\n\n';
        }
        const nextPageMethods = new Array();
        for (const method of client.methods) {
            // protocol creation can add imports to the list so
            // it must be done before the imports are written out
            if (go.isLROMethod(method)) {
                // generate Begin method
                opText += generateLROBeginMethod(client, method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes);
            }
            opText += generateOperation(client, method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes);
            opText += createProtocolRequest(azureARM, client, method, imports);
            if (method.kind !== 'lroMethod') {
                // LRO responses are handled elsewhere, with the exception of pageable LROs
                opText += createProtocolResponse(client, method, imports);
            }
            if ((method.kind === 'lroPageableMethod' || method.kind === 'pageableMethod') && method.nextPageMethod && !nextPageMethods.includes(method.nextPageMethod)) {
                // track the next page methods to generate as multiple operations can use the same next page operation
                nextPageMethods.push(method.nextPageMethod);
            }
        }
        for (const method of nextPageMethods) {
            opText += createProtocolRequest(azureARM, client, method, imports);
        }
        // stitch it all together
        let text = helpers.contentPreamble(codeModel);
        text += imports.text();
        text += clientText;
        text += opText;
        operations.push(new OperationGroupContent(client.name, text));
    }
    return operations;
}
// generates all modeled client constructors
// Generates the Go source for every modeled constructor of the specified client.
// azureARM - when true, the constructors wrap an arm.Client instead of an azcore.Client
// client - the client whose constructors are generated
// imports - import manager; receives the packages for each constructor param type
// returns the generated Go source text (empty when the client has no constructors)
// NOTE: each constructor.parameters array is sorted in place with helpers.sortParametersByRequired.
function generateConstructors(azureARM, client, imports) {
    if (client.constructors.length === 0) {
        return '';
    }
    // the package providing the underlying client is the same for every constructor
    const clientPkg = azureARM ? 'arm' : 'azcore';
    const optionsName = client.options.name;
    let ctorText = '';
    for (const constructor of client.constructors) {
        const ctorParams = new Array();
        const paramDocs = new Array();
        constructor.parameters.sort(helpers.sortParametersByRequired);
        for (const ctorParam of constructor.parameters) {
            imports.addImportForType(ctorParam.type);
            ctorParams.push(`${ctorParam.name} ${helpers.formatParameterTypeName(ctorParam)}`);
            if (ctorParam.docs.summary || ctorParam.docs.description) {
                paramDocs.push(helpers.formatCommentAsBulletItem(ctorParam.name, ctorParam.docs));
            }
        }
        // add client options last
        ctorParams.push(`${optionsName} ${helpers.formatParameterTypeName(client.options)}`);
        paramDocs.push(helpers.formatCommentAsBulletItem(optionsName, client.options.docs));
        ctorText += `// ${constructor.name} creates a new instance of ${client.name} with the specified values.\n`;
        ctorText += paramDocs.join('');
        ctorText += `func ${constructor.name}(${ctorParams.join(', ')}) (*${client.name}, error) {\n`;
        // pass the options param by the name it was declared with above
        ctorText += `\tcl, err := ${clientPkg}.NewClient(moduleName, moduleVersion, credential, ${optionsName})\n`;
        ctorText += '\tif err != nil {\n';
        ctorText += '\t\treturn nil, err\n';
        ctorText += '\t}\n';
        // construct client literal
        ctorText += `\tclient := &${client.name}{\n`;
        for (const parameter of values(client.parameters)) {
            // the client struct declares no field for literal params, so there's nothing to set
            if (go.isLiteralParameter(parameter)) {
                continue;
            }
            // each remaining client field is assumed to have a matching constructor parameter with the same name.
            // NOTE(review): grouped params live under a single group field in the client struct; confirm they're handled
            ctorText += `\t\t${parameter.name}: ${parameter.name},\n`;
        }
        ctorText += '\t\tinternal: cl,\n';
        ctorText += '\t}\n';
        ctorText += '\treturn client, nil\n';
        ctorText += '}\n\n';
    }
    return ctorText;
}
// creates a modeled constructor for an ARM client
// client - the ARM client to build a constructor for
// imports - import manager; the arm package is added to it
// returns a go.Constructor named New<client> taking the client's modeled params followed by a credential
function createARMClientConstructor(client, imports) {
    imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm');
    const ctor = new go.Constructor(`New${client.name}`);
    // modeled client params go first (expected to be just the subscriptionID)
    ctor.parameters.push(...client.parameters);
    // followed by the azcore.TokenCredential used to authenticate the client
    const credential = new go.Parameter('credential', new go.QualifiedType('TokenCredential', 'github.com/Azure/azure-sdk-for-go/sdk/azcore'), 'required', true, 'client');
    credential.docs.summary = 'used to authorize requests. Usually a credential from azidentity.';
    ctor.parameters.push(credential);
    return ctor;
}
// use this to generate the code that will help process values returned in response headers
// headerResp - the header response to process; a headerMapResponse collects every header whose name
//   starts with headerName into a map, anything else reads the single header named headerName
// imports - import manager; receives any Go packages the generated code needs
// respObj - the name of the Go variable holding the response envelope being populated
// zeroResp - the Go expression returned alongside a parse error
// returns the generated Go source text
// throws CodegenError for header types, scalar types, or time formats this generator doesn't handle
function formatHeaderResponseValue(headerResp, imports, respObj, zeroResp) {
    // dictionaries are handled slightly different so we do that first
    if (headerResp.kind === 'headerMapResponse') {
        imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
        imports.add('strings');
        const headerPrefix = headerResp.headerName;
        let text = '\tfor hh := range resp.Header {\n';
        text += `\t\tif len(hh) > len("${headerPrefix}") && strings.EqualFold(hh[:len("${headerPrefix}")], "${headerPrefix}") {\n`;
        text += `\t\t\tif ${respObj}.${headerResp.fieldName} == nil {\n`;
        text += `\t\t\t\t${respObj}.${headerResp.fieldName} = map[string]*string{}\n`;
        text += '\t\t\t}\n';
        text += `\t\t\t${respObj}.${headerResp.fieldName}[hh[len("${headerPrefix}"):]] = to.Ptr(resp.Header.Get(hh))\n`;
        text += '\t\t}\n';
        text += '\t}\n';
        return text;
    }
    let text = `\tif val := resp.Header.Get("${headerResp.headerName}"); val != "" {\n`;
    // name is the Go expression holding the parsed value; byRef controls whether its address is taken
    let name = uncapitalize(headerResp.fieldName);
    let byRef = '&';
    switch (headerResp.type.kind) {
        case 'constant':
            text += `\t\t${respObj}.${headerResp.fieldName} = (*${headerResp.type.name})(&val)\n`;
            text += '\t}\n';
            return text;
        case 'encodedBytes':
            // a base-64 encoded value in string format
            imports.add('encoding/base64');
            text += `\t\t${name}, err := base64.${helpers.formatBytesEncoding(headerResp.type.encoding)}Encoding.DecodeString(val)\n`;
            // the decoded value is assigned directly (no address taken)
            byRef = '';
            break;
        case 'literal':
            text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`;
            text += '\t}\n';
            return text;
        case 'scalar':
            imports.add('strconv');
            switch (headerResp.type.type) {
                case 'bool':
                    text += `\t\t${name}, err := strconv.ParseBool(val)\n`;
                    break;
                case 'float32':
                    text += `\t\t${name}32, err := strconv.ParseFloat(val, 32)\n`;
                    text += `\t\t${name} := float32(${name}32)\n`;
                    break;
                case 'float64':
                    text += `\t\t${name}, err := strconv.ParseFloat(val, 64)\n`;
                    break;
                case 'int32':
                    text += `\t\t${name}32, err := strconv.ParseInt(val, 10, 32)\n`;
                    text += `\t\t${name} := int32(${name}32)\n`;
                    break;
                case 'int64':
                    text += `\t\t${name}, err := strconv.ParseInt(val, 10, 64)\n`;
                    break;
                default:
                    throw new CodegenError('InternalError', `unhandled scalar type ${headerResp.type.type}`);
            }
            break;
        case 'string':
            text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`;
            text += '\t}\n';
            return text;
        case 'time':
            imports.add('time');
            switch (headerResp.type.format) {
                case 'dateTimeRFC1123':
                case 'dateTimeRFC3339':
                    text += `\t\t${name}, err := time.Parse(${headerResp.type.format === 'dateTimeRFC1123' ? helpers.datetimeRFC1123Format : helpers.datetimeRFC3339Format}, val)\n`;
                    break;
                case 'dateType':
                    text += `\t\t${name}, err := time.Parse("${helpers.dateFormat}", val)\n`;
                    break;
                case 'timeRFC3339':
                    text += `\t\t${name}, err := time.Parse("${helpers.timeRFC3339Format}", val)\n`;
                    break;
                case 'timeUnix':
                    imports.add('strconv');
                    imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
                    text += '\t\tsec, err := strconv.ParseInt(val, 10, 64)\n';
                    // to.Ptr already yields a pointer, so no address is taken
                    name = 'to.Ptr(time.Unix(sec, 0))';
                    byRef = '';
                    break;
                default:
                    // without this, the tail below would reference an undeclared err/name
                    throw new CodegenError('InternalError', `unhandled time format ${headerResp.type.format}`);
            }
            break;
        default:
            // without this, the tail below would reference an undeclared err/name
            throw new CodegenError('InternalError', `unhandled header response type ${headerResp.type.kind}`);
    }
    // NOTE: only cases that required parsing will fall through to here
    text += '\t\tif err != nil {\n';
    text += `\t\t\treturn ${zeroResp}, err\n`;
    text += '\t\t}\n';
    text += `\t\t${respObj}.${headerResp.fieldName} = ${byRef}${name}\n`;
    text += '\t}\n';
    return text;
}
// Returns the Go expression to return in place of a result when an error occurs.
// For LROs, both the public API ('api', which returns a *Poller[T]) and the internal
// operation ('op', which returns an *http.Response) return pointers, so the zero value is nil.
// Everything else returns an empty response envelope literal.
function getZeroReturnValue(method, apiType) {
    const emptyEnvelope = `${method.responseEnvelope.name}{}`;
    const returnsPointer = go.isLROMethod(method) && (apiType === 'api' || apiType === 'op');
    return returnsPointer ? 'nil' : emptyEnvelope;
}
// Helper function to generate nil checks for a dotted path
// Builds a Go boolean expression that nil-checks every prefix of a dotted field path,
// e.g. 'Value.NextLink' -> 'page.Value != nil && page.Value.NextLink != nil'.
// path - dot-separated field path
// prefix - the Go variable the path is rooted at (defaults to 'page')
function generateNilChecks(path, prefix = 'page') {
    const conditions = [];
    let accessor = prefix;
    for (const segment of path.split('.')) {
        accessor = `${accessor}.${segment}`;
        conditions.push(`${accessor} != nil`);
    }
    return conditions.join(' && ');
}
// Generates the runtime.NewPager(...) expression for a pageable method.
// The More callback reports whether a non-empty next link is present (always false when the
// method has no nextLinkName). The Fetcher callback either follows the next link through
// runtime.FetcherForNextLink or, for single-page pagers, sends the request once and requires
// an HTTP 200 response.
// client - the client that owns the method
// method - the pageable method
// imports - import manager; "context" is always added
// injectSpans - when true, the pager is given the client's tracer
// generateFakes - when true, the Fetcher stores the API name in the context under runtime.CtxAPINameKey
// returns the generated Go expression text
function emitPagerDefinition(client, method, imports, injectSpans, generateFakes) {
imports.add('context');
let text = `runtime.NewPager(runtime.PagingHandler[${method.responseEnvelope.name}]{\n`;
text += `\t\tMore: func(page ${method.responseEnvelope.name}) bool {\n`;
// there is no advancer for single-page pagers
if (method.nextLinkName) {
// nextLinkName may be a dotted path, so every segment along it is nil-checked
const nilChecks = generateNilChecks(method.nextLinkName);
text += `\t\t\treturn ${nilChecks} && len(*page.${method.nextLinkName}) > 0\n`;
text += '\t\t},\n';
}
else {
text += '\t\t\treturn false\n';
text += '\t\t},\n';
}
text += `\t\tFetcher: func(ctx context.Context, page *${method.responseEnvelope.name}) (${method.responseEnvelope.name}, error) {\n`;
const reqParams = helpers.getCreateRequestParameters(method);
if (generateFakes) {
text += `\t\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, "${client.name}.${fixUpMethodName(method)}")\n`;
}
if (method.nextLinkName) {
let nextLinkVar;
if (method.kind === 'pageableMethod') {
// page may be nil (presumably on the first fetch), in which case the next link is empty
text += '\t\t\tnextLink := ""\n';
nextLinkVar = 'nextLink';
text += '\t\t\tif page != nil {\n';
text += `\t\t\t\tnextLink = *page.${method.nextLinkName}\n\t\t\t}\n`;
}
else {
// NOTE(review): page is dereferenced without a nil check here; presumably LRO pagers
// always start from an existing page. Confirm.
nextLinkVar = `*page.${method.nextLinkName}`;
}
text += `\t\t\tresp, err := runtime.FetcherForNextLink(ctx, client.internal.Pipeline(), ${nextLinkVar}, func(ctx context.Context) (*policy.Request, error) {\n`;
text += `\t\t\t\treturn client.${method.naming.requestMethod}(${reqParams})\n\t\t\t}, `;
// nextPageMethod might be absent in some cases, see https://github.com/Azure/autorest/issues/4393
if (method.nextPageMethod) {
// the signature string is split back into "name type" pairs; a string param whose name
// starts with "next" is treated as the next link and replaced by the encoded next link
const nextOpParams = helpers.getCreateRequestParametersSig(method.nextPageMethod).split(',');
// keep the parameter names from the name/type tuples and find nextLink param
for (let i = 0; i < nextOpParams.length; ++i) {
const paramName = nextOpParams[i].trim().split(' ')[0];
const paramType = nextOpParams[i].trim().split(' ')[1];
if (paramName.startsWith('next') && paramType === 'string') {
nextOpParams[i] = 'encodedNextLink';
}
else {
nextOpParams[i] = paramName;
}
}
// add a definition for the nextReq func that uses the nextLinkOperation
text += '&runtime.FetcherForNextLinkOptions{\n\t\t\t\tNextReq: func(ctx context.Context, encodedNextLink string) (*policy.Request, error) {\n';
text += `\t\t\t\t\treturn client.${method.nextPageMethod.name}(${nextOpParams.join(', ')})\n\t\t\t\t},\n\t\t\t})\n`;
}
else {
// no next page method: FetcherForNextLink is given nil options
text += 'nil)\n';
}
text += `\t\t\tif err != nil {\n\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n\t\t\t}\n`;
text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`;
text += '\t\t\t},\n';
}
else {
// this is the singular page case, no fetcher helper required
text += `\t\t\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`;
text += '\t\t\tif err != nil {\n';
text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n`;
text += '\t\t\t}\n';
text += '\t\t\tresp, err := client.internal.Pipeline().Do(req)\n';
text += '\t\t\tif err != nil {\n';
text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n`;
text += '\t\t\t}\n';
text += '\t\t\tif !runtime.HasStatusCode(resp, http.StatusOK) {\n';
text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, runtime.NewResponseError(resp)\n`;
text += '\t\t\t}\n';
text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`;
text += '\t\t},\n';
}
if (injectSpans) {
text += '\t\tTracer: client.internal.Tracer(),\n';
}
text += '\t})\n';
return text;
}
// Returns the doc-comment lines listing the API versions a method was generated from,
// or an empty string when there are none.
function genApiVersionDoc(apiVersions) {
    return apiVersions.length > 0
        ? `//\n// Generated from API version ${apiVersions.join(', ')}\n`
        : '';
}
// Returns the doc-comment line stating that the operation returns an *azcore.ResponseError on failure.
// It is omitted for head-as-boolean methods, which return no error for 4xx status codes,
// and for pager constructors, which never return an error.
function genRespErrorDoc(method) {
    const isHeadAsBoolean = method.responseEnvelope.result?.kind === 'headAsBooleanResult';
    if (isHeadAsBoolean || go.isPageableMethod(method)) {
        return '';
    }
    return '// If the operation fails it returns an *azcore.ResponseError type.\n';
}
// Generates the Go method for a client operation.
// Pageable methods return the pager built by emitPagerDefinition. LRO methods are emitted under
// method.naming.internalMethod and return the raw *http.Response (nil on error). All other methods
// send the request, check the status code against method.httpStatusCodes, and build the response envelope.
// client - the client that owns the method
// method - the method to generate
// imports - import manager; receives any Go packages the generated code needs
// injectSpans - when true, the method is wrapped in a tracing span via runtime.StartSpan
// generateFakes - when true, the API name is stored in the context under runtime.CtxAPINameKey
// returns the generated Go source text
function generateOperation(client, method, imports, injectSpans, generateFakes) {
const params = getAPIParametersSig(method, imports);
const returns = generateReturnsInfo(method, 'op');
let methodName = method.name;
if (method.kind === 'pageableMethod') {
methodName = fixUpMethodName(method);
}
let text = '';
const respErrDoc = genRespErrorDoc(method);
const apiVerDoc = genApiVersionDoc(method.apiVersions);
if (method.docs.summary || method.docs.description) {
text += helpers.formatDocCommentWithPrefix(methodName, method.docs);
}
else if (respErrDoc.length > 0 || apiVerDoc.length > 0) {
// if the method has no doc comment but we're adding other
// doc comments, add an empty method name comment. this preserves
// existing behavior and makes the docs look better overall.
text += `// ${methodName} -\n`;
}
text += respErrDoc;
text += apiVerDoc;
if (go.isLROMethod(method)) {
// LRO Begin methods are generated separately (generateLROBeginMethod); here the
// internal method is emitted and its per-parameter docs are omitted
methodName = method.naming.internalMethod;
}
else {
for (const param of values(helpers.getMethodParameters(method))) {
text += helpers.formatCommentAsBulletItem(param.name, param.docs);
}
}
text += `func (client *${client.name}) ${methodName}(${params}) (${returns.join(', ')}) {\n`;
const reqParams = helpers.getCreateRequestParameters(method);
if (method.kind === 'pageableMethod') {
// pager constructors only return the pager; requests are sent by its Fetcher
text += '\treturn ';
text += emitPagerDefinition(client, method, imports, injectSpans, generateFakes);
text += '}\n\n';
return text;
}
text += '\tvar err error\n';
let operationName = `"${client.name}.${fixUpMethodName(method)}"`;
if (generateFakes && injectSpans) {
// the name is referenced twice, so hoist it into a Go constant
text += `\tconst operationName = ${operationName}\n`;
operationName = 'operationName';
}
if (generateFakes) {
text += `\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, ${operationName})\n`;
}
if (injectSpans) {
// err is declared up front so the deferred closure passes its final value to endSpan
text += `\tctx, endSpan := runtime.StartSpan(ctx, ${operationName}, client.internal.Tracer(), nil)\n`;
text += '\tdefer func() { endSpan(err) }()\n';
}
const zeroResp = getZeroReturnValue(method, 'op');
text += `\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`;
text += '\tif err != nil {\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
text += '\thttpResp, err := client.internal.Pipeline().Do(req)\n';
text += '\tif err != nil {\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
text += `\tif !runtime.HasStatusCode(httpResp, ${helpers.formatStatusCodes(method.httpStatusCodes)}) {\n`;
text += '\t\terr = runtime.NewResponseError(httpResp)\n';
text += `\t\treturn ${zeroResp}, err\n`;
text += '\t}\n';
// HAB with headers response is handled in protocol responder
if (method.responseEnvelope.result?.kind === 'headAsBooleanResult' && method.responseEnvelope.headers.length === 0) {
// head-as-boolean: the result is true for any 2xx status code
text += `\treturn ${method.responseEnvelope.name}{${method.responseEnvelope.result.fieldName}: httpResp.StatusCode >= 200 && httpResp.StatusCode < 300}, nil\n`;
}
else {
if (go.isLROMethod(method)) {
text += '\treturn httpResp, nil\n';
}
else if (needsResponseHandler(method)) {
// also cheating here as at present the only param to the responder is an http.Response
text += `\tresp, err := client.${method.naming.responseMethod}(httpResp)\n`;
text += '\treturn resp, err\n';
}
else if (method.responseEnvelope.result?.kind === 'binaryResult') {
// the response body stream is handed to the caller as-is
text += `\treturn ${method.responseEnvelope.name}{${method.responseEnvelope.result.fieldName}: httpResp.Body}, nil\n`;
}
else {
text += `\treturn ${method.responseEnvelope.name}{}, nil\n`;
}
}
text += '}\n\n';
return text;
}
function createProtocolRequest(azureARM, client, method, imports) {
let name = method.name;
if (method.kind !== 'nextPageMethod') {
name = method.naming.requestMethod;
}
for (const param of values(method.parameters)) {
if (param.location !== 'method' || !go.isRequiredParameter(param)) {
continue;
}
imports.addImportForType(param.type);
}
const returns = ['*policy.Request', 'error'];
let text = `${comment(name, '// ')} creates the ${method.name} request.\n`;
text += `func (client *${client.name}) ${name}(${helpers.getCreateRequestParametersSig(method)}) (${returns.join(', ')}) {\n`;
const hostParams = new Array();
for (const parameter of client.parameters) {
if (parameter.kind === 'uriParam') {
hostParams.push(parameter);
}
}
let hostParam;
if (azureARM) {
hostParam = 'client.internal.Endpoint()';
}
else if (client.templatedHost) {
imports.add('strings');
// we have a templated host
text += `\thost := "${client.templatedHost}"\n`;
// get all the host params on the client
for (const hostParam of hostParams) {
text += `\thost = strings.ReplaceAll(host, "{${hostParam.uriPathSegment}}", ${helpers.formatValue(`client.${hostParam.name}`, hostParam.type, imports)})\n`;
}
// check for any method local host params
for (const param of values(method.parameters)) {
if (param.location === 'method' && param.kind === 'uriParam') {
text += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(helpers.getParamName(param), param.type, imports)})\n`;
}
}
hostParam = 'host';
}
else if (hostParams.length === 1) {
// simple parameterized host case
hostParam = 'client.' + hostParams[0].name;
}
else {
throw new CodegenError('InternalError', `no host or endpoint defined for method ${client.name}.${method.name}`);
}
const methodParamGroups = helpers.getMethodParamGroups(method);
const hasPathParams = methodParamGroups.pathParams.length > 0;
// storage needs the client.u to be the source-of-truth for the full path.
// however, swagger requires that all operations specify a path, which is at odds with storage.
// to work around this, storage specifies x-ms-path paths with path params but doesn't
// actually reference the path params (i.e. no params with which to replace the tokens).
// so, if a path contains tokens but there are no path params, skip emitting the path.
const pathStr = method.httpPath;
const pathContainsParms = pathStr.includes('{');
if (hasPathParams || (!pathContainsParms && pathStr.length > 1)) {
// there are path params, or the path doesn't contain tokens and is not "/" so emit it
text += `\turlPath := "${method.httpPath}"\n`;
hostParam = `runtime.JoinPaths(${hostParam}, urlPath)`;
}
if (hasPathParams) {
// swagger defines path params, emit path and replace tokens
imports.add('strings');
// replace path parameters
for (const pp of methodParamGroups.pathParams) {
// emit check to ensure path param isn't an empty string. we only need
// to do this for params that have an underlying type of string.
const choiceIsString = function (type) {
return type.kind === 'constant' && type.type === 'string';
};
// TODO: https://github.com/Azure/autorest.go/issues/1593
if (pp.kind === 'pathScalarParam' && ((pp.type.kind === 'string' || choiceIsString(pp.type)) && pp.isEncoded)) {
const paramName = helpers.getParamName(pp);
imports.add('errors');
text += `\tif ${paramName} == "" {\n`;
text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`;
text += '\t}\n';
}
let paramValue = helpers.formatParamValue(pp, imports);
if (pp.isEncoded) {
imports.add('net/url');
paramValue = `url.PathEscape(${helpers.formatParamValue(pp, imports)})`;
}
text += `\turlPath = strings.ReplaceAll(urlPath, "{${pp.pathSegment}}", ${paramValue})\n`;
}
}
text += `\treq, err := runtime.NewRequest(ctx, http.Method${capitalize(method.httpMethod)}, ${hostParam})\n`;
text += '\tif err != nil {\n';
text += '\t\treturn nil, err\n';
text += '\t}\n';
// helper to build nil checks for param groups
const emitParamGroupCheck = function (param) {
if (!param.group) {
throw new CodegenError('InternalError', `emitParamGroupCheck called for ungrouped parameter ${param.name}`);
}
let client = '';
if (param.location === 'client') {
client = 'client.';
}
const paramGroupName = uncapitalize(param.group.name);
let optionalParamGroupCheck = `${client}${paramGroupName} != nil && `;
if (param.group.required) {
optionalParamGroupCheck = '';
}
return `\tif ${optionalParamGroupCheck}${client}${paramGroupName}.${capitalize(param.name)} != nil {\n`;
};
// add query parameters
const encodedParams = methodParamGroups.encodedQueryParams;
const unencodedParams = methodParamGroups.unencodedQueryParams;
const emitQueryParam = function (qp, setter) {
let qpText = '';
if (qp.location === 'method' && go.isClientSideDefault(qp.style)) {
qpText = emitClientSideDefault(qp, qp.style, (name, val) => { return `\treqQP.Set(${name}, ${val})`; }, imports);
}
else if (go.isRequiredParameter(qp) || go.isLiteralParameter(qp) || (qp.location === 'client' && go.isClientSideDefault(qp.style))) {
qpText = `\t${setter}\n`;
}
else if (qp.location === 'client' && !qp.group) {
// global optional param
qpText = `\tif client.${qp.name} != nil {\n`;
qpText += `\t\t${setter}\n`;
qpText += '\t}\n';
}
else {
qpText = emitParamGroupCheck(qp);
qpText += `\t\t${setter}\n`;
qpText += '\t}\n';
}
return qpText;
};
// emit encoded params first
if (encodedParams.length > 0) {
text += '\treqQP := req.Raw().URL.Query()\n';
for (const qp of values(encodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) {
let setter;
if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') {
setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`;
// emit a type conversion for the qv based on the array's element type
let queryVal;
const arrayQP = qp.type;
switch (arrayQP.elementType.kind) {
case 'constant':
switch (arrayQP.elementType.type) {
case 'string':
queryVal = 'string(qv)';
break;
default:
imports.add('fmt');
queryVal = 'fmt.Sprintf("%d", qv)';
}
break;
case 'string':
queryVal = 'qv';
break;
default:
imports.add('fmt');
queryVal = 'fmt.Sprintf("%v", qv)';
}
setter += `\t\treqQP.Add("${qp.queryParameter}", ${queryVal})\n`;
setter += '\t}';
}
else {
// cannot initialize setter to this value as helpers.formatParamValue() can change imports
setter = `reqQP.Set("${qp.queryParameter}", ${helpers.formatParamValue(qp, imports)})`;
}
text += emitQueryParam(qp, setter);
}
text += '\treq.Raw().URL.RawQuery = reqQP.Encode()\n';
}
// tack on any unencoded params to the end
if (unencodedParams.length > 0) {
if (encodedParams.length > 0) {
text += '\tunencodedParams := []string{req.Raw().URL.RawQuery}\n';
}
else {
text += '\tunencodedParams := []string{}\n';
}
for (const qp of values(unencodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) {
let setter;
if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') {
setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`;
setter += `\t\tunencodedParams = append(unencodedParams, "${qp.queryParameter}="+qv)\n`;
setter += '\t}';
}
else {
setter = `unencodedParams = append(unencodedParams, "${qp.queryParameter}="+${helpers.formatParamValue(qp, imports)})`;
}
text += emitQueryParam(qp, setter);
}
imports.add('strings');
text += '\treq.Raw().URL.RawQuery = strings.Join(unencodedParams, "&")\n';
}
if (method.kind !== 'nextPageMethod' && method.responseEnvelope.result?.kind === 'binaryResult') {
// skip auto-body downloading for binary stream responses
text += '\truntime.SkipBodyDownload(req)\n';
}
// add specific request headers
const emitHeaderSet = function (headerParam, prefix) {
if (headerParam.kind === 'headerMapParam') {
let headerText = `${prefix}for k, v := range ${helpers.getParamName(headerParam)} {\n`;
headerText += `${prefix}\tif v != nil {\n`;
headerText += `${prefix}\t\treq.Raw().Header["${headerParam.headerName}"+k] = []string{*v}\n`;
headerText += `${prefix}}\n`;
headerText += `${prefix}}\n`;
return headerText;
}
else if (headerParam.location === 'method' && go.isClientSideDefault(headerParam.style)) {
return emitClientSideDefault(headerParam, headerParam.style, (name, val) => {
return `${prefix}req.Raw().Header[${name}] = []string{${val}}`;
}, imports);
}
else {
return `${prefix}req.Raw().Header["${headerParam.headerName}"] = []string{${helpers.formatParamValue(headerParam, imports)}}\n`;
}
};
let contentType;
for (const param of methodParamGroups.headerParams.sort((a, b) => { return helpers.sortAscending(a.headerName, b.headerName); })) {
if (param.headerName.match(/^content-type$/)) {
// canonicalize content-type as req.SetBody checks for it via its canonicalized name :(
param.headerName = 'Content-Type';
}
if (param.headerName === 'Content-Type' && param.style === 'literal') {
// the content-type header will be set as part of emitSetBodyWithErrCheck
// to handle cases where the body param is optional. we don't want to set
// the content-type if the body is nil.
// we do it like this as tsp specifies content-type while swagger does not.
contentType = helpers.formatParamValue(param, imports);
}
else if (go.isRequiredParameter(param) || go.isLiteralParameter(param) || go.isClientSideDefault(param.style)) {
text += emitHeaderSet(param, '\t');
}
else if (param.location === 'client' && !param.group) {
// global optional param
text += `\tif client.${param.name} != nil {\n`;
text += emitHeaderSet(param, '\t');
text += '\t}\n';
}
else {
text += emitParamGroupCheck(param);
text += emitHeaderSet(param, '\t\t');
text += '\t}\n';
}
}
// note that these are mutually exclusive
const bodyParam = methodParamGroups.bodyParam;
const formBodyParams = methodParamGroups.formBodyParams;
const multipartBodyParams = methodParamGroups.multipartBodyParams;
const partialBodyParams = methodParamGroups.partialBodyParams;
const emitSetBodyWithErrCheck = function (setBodyParam, contentType) {
let content = `if err := ${setBodyParam}; err != nil {\n\treturn nil, err\n}\n;`;
if (contentType) {
content = `req.Raw().Header["Content-Type"] = []string{${contentType}}\n` + content;
}
return content;
};
if (bodyParam) {
if (bodyParam.bodyFormat === 'JSON' || bodyParam.bodyFormat === 'XML') {
// default to the body param name
let body = helpers.getParamName(bodyParam);
if (bodyParam.type.kind === 'literal') {
// if the value is constant, embed it directly
body = helpers.formatLiteralValue(bodyParam.type, true);
}
else if (bodyParam.bodyFormat === 'XML' && bodyParam.type.kind === 'slice') {
// for XML payloads, create a wrapper type if the payload is an array
imports.add('encoding/xml');
text += '\ttype wrapper struct {\n';
let tagName = go.getTypeDeclaration(bodyParam.type);
if (bodyParam.xml?.name) {
tagName = bodyParam.xml.name;
}
text += `\t\tXMLName xml.Name \`xml:"${tagName}"\`\n`;
const fieldName = capitalize(bodyParam.name);
let tag = go.getTypeDeclaration(bodyParam.type.elementType);
if (bodyParam.type.elementType.kind === 'model' && bodyParam.type.elementType.xml?.name) {
tag = bodyParam.type.elementType.xml.name;
}
text += `\t\t${fieldName} *${go.getTypeDeclaration(bodyParam.type)} \`xml:"${tag}"\`\n`;
text += '\t}\n';
let addr = '&';
if (!go.isRequiredParameter(bodyParam) && !bodyParam.byValue) {
addr = '';
}
body = `wrapper{${fieldName}: ${addr}${body}}`;
}
else if (bodyParam.type.kind === 'time' && bodyParam.type.format !== 'dateTimeRFC3339') {
// wrap the body in the internal time type
// no need for dateTimeRFC3339 as the JSON marshaler defaults to that.
body = `${bodyParam.type.format}(${body})`;
}
else if (isArrayOfDateTimeForMarshalling(bodyParam.type)) {
const timeInfo = isArrayOfDateTimeForMarshalling(bodyParam.type);
let elementPtr = '*';
if (timeInfo?.elemByVal) {
elementPtr = '';
}
text += `\taux := make([]${elementPtr}${timeInfo?.format}, len(${body}))\n`;
text += `\tfor i := 0; i < len(${body}); i++ {\n`;
text += `\t\taux[i] = (${elementPtr}${timeInfo?.format})(${body}[i])\n`;
text += '\t}\n';
body = 'aux';
}
else if (isMapOfDateTime(bodyParam.type)) {
const timeType = isMapOfDateTime(bodyParam.type);
text += `\taux := map[string]*${timeType}{}\n`;
text += `\tfor k, v := range ${body} {\n`;
text += `\t\taux[k] = (*${timeType})(v)\n`;
text += '\t}\n';
body = 'aux';
}
let setBody = `runtime.MarshalAs${getMediaFormat(bodyParam.type, bodyParam.bodyFormat, `req, ${body}`)}`;
if (bodyParam.type.kind === 'rawJSON') {
imports.add('bytes');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
setBody = `req.SetBody(streaming.NopCloser(bytes.NewReader(${body})), "application/${bodyParam.bodyFormat.toLowerCase()}")`;
}
if (go.isRequiredParameter(bodyParam) || go.isLiteralParameter(bodyParam)) {
text += `\t${emitSetBodyWithErrCheck(setBody, contentType)}`;
text += '\treturn req, nil\n';
}
else {
text += emitParamGroupCheck(bodyParam);
text += `\t${emitSetBodyWithErrCheck(setBody, contentType)}`;
text += '\t\treturn req, nil\n';
text += '\t}\n';
text += '\treturn req, nil\n';
}
}
else if (bodyParam.bodyFormat === 'binary') {
if (go.isRequiredParameter(bodyParam)) {
text += `\t${emitSetBodyWithErrCheck(`req.SetBody(${bodyParam.name}, ${bodyParam.contentType})`, contentType)}`;
text += '\treturn req, nil\n';
}
else {
text += emitParamGroupCheck(bodyParam);
text += `\t${emitSetBodyWithErrCheck(`req.SetBody(${helpers.getParamName(bodyParam)}, ${bodyParam.contentType})`, contentType)}`;
text += '\treturn req, nil\n';
text += '\t}\n';
text += '\treturn req, nil\n';
}
}
else if (bodyParam.bodyFormat === 'Text') {
imports.add('strings');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
if (go.isRequiredParameter(bodyParam)) {
text += `\tbody := streaming.NopCloser(strings.NewReader(${bodyParam.name}))\n`;
text += `\t${emitSetBodyWithErrCheck(`req.SetBody(body, ${bodyParam.contentType})`, contentType)}`;
text += '\treturn req, nil\n';
}
else {
text += emitParamGroupCheck(bodyParam);
text += `\tbody := streaming.NopCloser(strings.NewReader(${helpers.getParamName(bodyParam)}))\n`;
text += `\t${emitSetBodyWithErrCheck(`req.SetBody(body, ${bodyParam.contentType})`, contentType)}`;
text += '\treturn req, nil\n';
text += '\t}\n';
text += '\treturn req, nil\n';
}
}
}
else if (partialBodyParams.length > 0) {
// partial body params are discrete params that are all fields within an internal struct.
// define and instantiate an instance of the wire type, using the values from each param.
text += '\tbody := struct {\n';
for (const partialBodyParam of partialBodyParams) {
text += `\t\t${capitalize(partialBodyParam.serializedName)} ${helpers.star(partialBodyParam)}${go.getTypeDeclaration(partialBodyParam.type)} \`${partialBodyParam.format.toLowerCase()}:"${partialBodyParam.serializedName}"\`\n`;
}
text += '\t}{\n';
// required params are emitted as initializers in the struct literal
for (const partialBodyParam of partialBodyParams) {
if (go.isRequiredParameter(partialBodyParam)) {
text += `\t\t${capitalize(partialBodyParam.serializedName)}: ${uncapitalize(partialBodyParam.name)},\n`;
}
}
text += '\t}\n';
// now populate any optional params from the options type
for (const partialBodyParam of partialBodyParams) {
if (!go.isRequiredParameter(partialBodyParam)) {
text += emitParamGroupCheck(partialBodyParam);
text += `\t\tbody.${capitalize(partialBodyParam.serializedName)} = options.${capitalize(partialBodyParam.name)}\n\t}\n`;
}
}
// TODO: spread params are JSON only https://github.com/Azure/autorest.go/issues/1455
text += '\treq.Raw().Header["Content-Type"] = []string{"application/json"}\n';
text += '\tif err := runtime.MarshalAsJSON(req, body); err != nil {\n\t\treturn nil, err\n\t}\n';
text += '\treturn req, nil\n';
}
else if (multipartBodyParams.length > 0) {
if (multipartBodyParams.length === 1 && multipartBodyParams[0].type.kind === 'model' && multipartBodyParams[0].type.annotations.multipartFormData) {
text += `\tformData, err := ${multipartBodyParams[0].name}.toMultipartFormData()\n`;
text += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
}
else {
text += '\tformData := map[string]any{}\n';
for (const param of multipartBodyParams) {
const setter = `formData["${param.name}"] = ${helpers.getParamName(param)}`;
if (go.isRequiredParameter(param)) {
text += `\t${setter}\n`;
}
else {
text += emitParamGroupCheck(param);
text += `\t${setter}\n\t}\n`;
}
}
}
text += '\tif err := runtime.SetMultipartFormData(req, formData); err != nil {\n\t\treturn nil, err\n\t}\n';
text += '\treturn req, nil\n';
}
else if (formBodyParams.length > 0) {
const emitFormData = function (param, setter) {
let formDataText = '';
if (go.isRequiredParameter(param)) {
formDataText = `\t${setter}\n`;
}
else {
formDataText = emitParamGroupCheck(param);
formDataText += `\t\t${setter}\n`;
formDataText += '\t}\n';
}
return formDataText;
};
imports.add('net/url');
imports.add('strings');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
text += '\tformData := url.Values{}\n';
// find all the form body params
for (const param of formBodyParams) {
const setter = `formData.Set("${param.formDataName}", ${helpers.formatParamValue(param, imports)})`;
text += emitFormData(param, setter);
}
text += '\tbody := streaming.NopCloser(strings.NewReader(formData.Encode()))\n';
text += `\t${emitSetBodyWithErrCheck('req.SetBody(body, "application/x-www-form-urlencoded")')}`;
text += '\treturn req, nil\n';
}
else {
text += '\treturn req, nil\n';
}
text += '}\n\n';
return text;
}
// emits Go code for an optional header/query param that has a client-side default:
// a local is seeded with the default literal, replaced by *options.<Param> when the
// caller supplied a value, and then applied through setterFormat.
//   param - the header or query parameter carrying the client-side default
//   csd - the client-side default descriptor; csd.defaultValue is the literal emitted
//   setterFormat - callback(quotedSerializedName, value) that returns the Go setter statement
//   imports - import manager forwarded to helpers.formatValue
// returns the generated Go source text.
// throws if param.kind is not a header or query parameter kind.
function emitClientSideDefault(param, csd, setterFormat, imports) {
    const defaultVar = uncapitalize(param.name) + 'Default';
    const optionsField = capitalize(param.name);
    let text = `\t${defaultVar} := ${helpers.formatLiteralValue(csd.defaultValue, true)}\n`;
    text += `\tif options != nil && options.${optionsField} != nil {\n`;
    text += `\t\t${defaultVar} = *options.${optionsField}\n`;
    // indent the closing brace to match the opening if statement
    text += '\t}\n';
    let serializedName;
    switch (param.kind) {
        case 'headerCollectionParam':
        case 'headerScalarParam':
            serializedName = param.headerName;
            break;
        case 'queryCollectionParam':
        case 'queryScalarParam':
            serializedName = param.queryParameter;
            break;
        default:
            // fail at codegen time rather than emitting a setter keyed on "undefined"
            throw new Error(`unhandled client-side default param kind ${param.kind}`);
    }
    text += setterFormat(`"${serializedName}"`, helpers.formatValue(defaultVar, param.type, imports)) + '\n';
    return text;
}
// builds the suffix that follows "runtime.MarshalAs" when setting a request body.
//   type - the body parameter's type
//   mediaType - the marshaller name for the body format (e.g. 'JSON', 'XML')
//   param - the already-formatted argument list, e.g. "req, body"
// encodedBytes types are routed to the ByteArray marshaller together with an explicit
// base64 format argument; every other type uses the marshaller named by mediaType.
function getMediaFormat(type, mediaType, param) {
    if (type.kind !== 'encodedBytes') {
        return `${mediaType}(${param})`;
    }
    const base64Format = `runtime.Base64${type.encoding}Format`;
    return `ByteArray(${param}, ${base64Format})`;
}
// reports whether a slice of time values must be copied into a slice of an internal
// time wrapper type before it can be marshalled.
// returns { format, elemByVal } for slices whose element time format is one of
// dateType, dateTimeRFC1123, timeRFC3339 or timeUnix; returns undefined for
// non-slices, non-time elements, and any other format (dateTimeRFC3339 uses the
// default marshaller).
function isArrayOfDateTimeForMarshalling(paramType) {
    if (paramType.kind !== 'slice' || paramType.elementType.kind !== 'time') {
        return undefined;
    }
    const elemFormat = paramType.elementType.format;
    const wrappedFormats = ['dateType', 'dateTimeRFC1123', 'timeRFC3339', 'timeUnix'];
    if (!wrappedFormats.includes(elemFormat)) {
        return undefined;
    }
    return {
        format: elemFormat,
        elemByVal: paramType.elementTypeByValue
    };
}
// determines whether a method needs a generated response handler, which is the case
// when the response has a schema'd body to unmarshal (helpers.hasSchemaResponse)
// and/or the response envelope declares headers to parse.
function needsResponseHandler(method) {
    const hasSchemaBody = helpers.hasSchemaResponse(method);
    if (hasSchemaBody) {
        return hasSchemaBody;
    }
    return method.responseEnvelope.headers.length > 0;
}
function generateResponseUnmarshaller(method, type, format, unmarshalTarget) {
let unmarshallerText = '';
const zeroValue = getZeroReturnValue(method, 'handler');
if (type.kind === 'time') {
// use the designated time type for unmarshalling
unmarshallerText += `\tvar aux *${type.format}\n`;
unmarshallerText += `\tif err := runtime.UnmarshalAs${format}(resp, &aux); err != nil {\n`;
unmarshallerText += `\t\treturn ${zeroValue}, err\n`;
unmarshallerText += '\t}\n';
unmarshallerText += `\tresult.${helpers.getResultFieldName(method)} = (*time.Time)(aux)\n`;
return unmarshallerText;
}
else if (isArrayOfDateTime(type)) {
// unmarshalling arrays of date/time is a little more involved
const timeInfo = isArrayOfDateTime(type);
let elementPtr = '*';
if (timeInfo?.elemByVal) {
elementPtr = '';
}
unmarshallerText += `\tvar aux []${elementPtr}${timeInfo?.format}\n`;
unmarshallerText += `\tif err := runtime.UnmarshalAs${format}(resp, &aux); err != nil {\n`;
unmarshallerText += `\t\treturn ${zeroValue}, err\n`;
unmarshallerText += '\t}\n';
unmarshallerText += `\tcp := make([]${elementPtr}time.Time, len(aux))\n`;
unmarshallerText += '\tfor i := 0; i < len(aux); i++ {\n