UNPKG

@autorest/go

Version:
1,026 lines 64.3 kB
/*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ import * as go from '../../codemodel.go/src/index.js'; import { capitalize, comment, uncapitalize } from '@azure-tools/codegen'; import { values } from '@azure-tools/linq'; import * as helpers from './helpers.js'; import { ImportManager } from './imports.js'; import { CodegenError } from './errors.js'; // represents the generated content for an operation group export class OperationGroupContent { name; content; constructor(name, content) { this.name = name; this.content = content; } } // Creates the content for all <operation>.go files export async function generateOperations(codeModel) { // generate protocol operations const operations = new Array(); if (codeModel.clients.length === 0) { return operations; } const azureARM = codeModel.type === 'azure-arm'; for (const client of codeModel.clients) { // the list of packages to import const imports = new ImportManager(); if (client.methods.length > 0) { // add standard imports for clients with methods. // clients that are purely hierarchical (i.e. having no APIs) won't need them. 
imports.add('net/http'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime'); } let clientPkg = 'azcore'; if (azureARM) { clientPkg = 'arm'; imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm'); client.constructors.push(createARMClientConstructor(client, imports)); } else { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore'); } // generate client type let clientText = helpers.formatDocComment(client.docs); clientText += '// Don\'t use this type directly, use '; if (client.constructors.length === 1) { clientText += `${client.constructors[0].name}() instead.\n`; } else if (client.parent) { // find the accessor method let accessorMethod; for (const clientAccessor of client.parent.clientAccessors) { if (clientAccessor.subClient === client) { accessorMethod = clientAccessor.name; break; } } if (!accessorMethod) { throw new CodegenError('InternalError', `didn't find accessor method for client ${client.name} on parent client ${client.parent.name}`); } clientText += `[${client.parent.name}.${accessorMethod}] instead.\n`; } else { clientText += 'a constructor function instead.\n'; } clientText += `type ${client.name} struct {\n`; clientText += `\tinternal *${clientPkg}.Client\n`; // check for any optional host params const optionalParams = new Array(); const isParamPointer = function (param) { // for client params, only optional and flag types are passed by pointer return param.style === 'flag' || param.style === 'optional'; }; // now emit any client params (non parameterized host params case) if (client.parameters.length > 0) { const addedGroups = new Set(); for (const clientParam of values(client.parameters)) { if (go.isLiteralParameter(clientParam)) { continue; } if (clientParam.group) { if (!addedGroups.has(clientParam.group.groupName)) { clientText += `\t${uncapitalize(clientParam.group.groupName)} ${!isParamPointer(clientParam) ? 
'' : '*'}${clientParam.group.groupName}\n`; addedGroups.add(clientParam.group.groupName); } continue; } clientText += `\t${clientParam.name} `; if (!isParamPointer(clientParam)) { clientText += `${go.getTypeDeclaration(clientParam.type)}\n`; } else { clientText += `${helpers.formatParameterTypeName(clientParam)}\n`; } if (!go.isRequiredParameter(clientParam)) { optionalParams.push(clientParam); } } } // end of client definition clientText += '}\n\n'; if (azureARM && optionalParams.length > 0) { throw new CodegenError('UnsupportedTsp', 'optional client parameters for ARM is not supported'); } // generate client constructors clientText += generateConstructors(azureARM, client, imports); // generate client accessors and operations let opText = ''; for (const clientAccessor of client.clientAccessors) { opText += `// ${clientAccessor.name} creates a new instance of [${clientAccessor.subClient.name}].\n`; opText += `func (client *${client.name}) ${clientAccessor.name}() *${clientAccessor.subClient.name} {\n`; opText += `\treturn &${clientAccessor.subClient.name}{\n`; opText += '\t\tinternal: client.internal,\n'; // propagate all client params for (const param of client.parameters) { opText += `\t\t${param.name}: client.${param.name},\n`; } opText += '\t}\n}\n\n'; } const nextPageMethods = new Array(); for (const method of client.methods) { // protocol creation can add imports to the list so // it must be done before the imports are written out if (go.isLROMethod(method)) { // generate Begin method opText += generateLROBeginMethod(client, method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes); } opText += generateOperation(client, method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes); opText += createProtocolRequest(azureARM, client, method, imports); if (method.kind !== 'lroMethod') { // LRO responses are handled elsewhere, with the exception of pageable LROs opText += createProtocolResponse(client, method, imports); } 
if ((method.kind === 'lroPageableMethod' || method.kind === 'pageableMethod') && method.nextPageMethod && !nextPageMethods.includes(method.nextPageMethod)) { // track the next page methods to generate as multiple operations can use the same next page operation nextPageMethods.push(method.nextPageMethod); } } for (const method of nextPageMethods) { opText += createProtocolRequest(azureARM, client, method, imports); } // stitch it all together let text = helpers.contentPreamble(codeModel); text += imports.text(); text += clientText; text += opText; operations.push(new OperationGroupContent(client.name, text)); } return operations; } // generates all modeled client constructors function generateConstructors(azureARM, client, imports) { if (client.constructors.length === 0) { return ''; } let ctorText = ''; for (const constructor of client.constructors) { const ctorParams = new Array(); const paramDocs = new Array(); constructor.parameters.sort(helpers.sortParametersByRequired); for (const ctorParam of constructor.parameters) { imports.addImportForType(ctorParam.type); ctorParams.push(`${ctorParam.name} ${helpers.formatParameterTypeName(ctorParam)}`); if (ctorParam.docs.summary || ctorParam.docs.description) { paramDocs.push(helpers.formatCommentAsBulletItem(ctorParam.name, ctorParam.docs)); } } // add client options last ctorParams.push(`${client.options.name} ${helpers.formatParameterTypeName(client.options)}`); paramDocs.push(helpers.formatCommentAsBulletItem(client.options.name, client.options.docs)); ctorText += `// ${constructor.name} creates a new instance of ${client.name} with the specified values.\n`; for (const doc of paramDocs) { ctorText += doc; } ctorText += `func ${constructor.name}(${ctorParams.join(', ')}) (*${client.name}, error) {\n`; let clientType = 'azcore'; if (azureARM) { clientType = 'arm'; } ctorText += `\tcl, err := ${clientType}.NewClient(moduleName, moduleVersion, credential, options)\n`; ctorText += '\tif err != nil {\n'; ctorText += 
'\t\treturn nil, err\n'; ctorText += '\t}\n'; // construct client literal ctorText += `\tclient := &${client.name}{\n`; for (const parameter of values(client.parameters)) { // each client field will have a matching parameter with the same name ctorText += `\t\t${parameter.name}: ${parameter.name},\n`; } ctorText += '\tinternal: cl,\n'; ctorText += '\t}\n'; ctorText += '\treturn client, nil\n'; ctorText += '}\n\n'; } return ctorText; } // creates a modeled constructor for an ARM client function createARMClientConstructor(client, imports) { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm'); const ctor = new go.Constructor(`New${client.name}`); // add any modeled parameter first, which should only be the subscriptionID, then add TokenCredential for (const param of client.parameters) { ctor.parameters.push(param); } const tokenCredParam = new go.Parameter('credential', new go.QualifiedType('TokenCredential', 'github.com/Azure/azure-sdk-for-go/sdk/azcore'), 'required', true, 'client'); tokenCredParam.docs.summary = 'used to authorize requests. 
Usually a credential from azidentity.'; ctor.parameters.push(tokenCredParam); return ctor; } // use this to generate the code that will help process values returned in response headers function formatHeaderResponseValue(headerResp, imports, respObj, zeroResp) { // dictionaries are handled slightly different so we do that first if (headerResp.kind === 'headerMapResponse') { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to'); imports.add('strings'); const headerPrefix = headerResp.headerName; let text = '\tfor hh := range resp.Header {\n'; text += `\t\tif len(hh) > len("${headerPrefix}") && strings.EqualFold(hh[:len("${headerPrefix}")], "${headerPrefix}") {\n`; text += `\t\t\tif ${respObj}.${headerResp.fieldName} == nil {\n`; text += `\t\t\t\t${respObj}.${headerResp.fieldName} = map[string]*string{}\n`; text += '\t\t\t}\n'; text += `\t\t\t${respObj}.${headerResp.fieldName}[hh[len("${headerPrefix}"):]] = to.Ptr(resp.Header.Get(hh))\n`; text += '\t\t}\n'; text += '\t}\n'; return text; } let text = `\tif val := resp.Header.Get("${headerResp.headerName}"); val != "" {\n`; let name = uncapitalize(headerResp.fieldName); let byRef = '&'; switch (headerResp.type.kind) { case 'constant': text += `\t\t${respObj}.${headerResp.fieldName} = (*${headerResp.type.name})(&val)\n`; text += '\t}\n'; return text; case 'encodedBytes': // a base-64 encoded value in string format imports.add('encoding/base64'); text += `\t\t${name}, err := base64.${helpers.formatBytesEncoding(headerResp.type.encoding)}Encoding.DecodeString(val)\n`; byRef = ''; break; case 'literal': text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`; text += '\t}\n'; return text; case 'scalar': imports.add('strconv'); switch (headerResp.type.type) { case 'bool': text += `\t\t${name}, err := strconv.ParseBool(val)\n`; break; case 'float32': text += `\t\t${name}32, err := strconv.ParseFloat(val, 32)\n`; text += `\t\t${name} := float32(${name}32)\n`; break; case 'float64': text += `\t\t${name}, err := 
strconv.ParseFloat(val, 64)\n`; break; case 'int32': text += `\t\t${name}32, err := strconv.ParseInt(val, 10, 32)\n`; text += `\t\t${name} := int32(${name}32)\n`; break; case 'int64': text += `\t\t${name}, err := strconv.ParseInt(val, 10, 64)\n`; break; default: throw new CodegenError('InternalError', `unhandled scalar type ${headerResp.type.type}`); } break; case 'string': text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`; text += '\t}\n'; return text; case 'time': imports.add('time'); switch (headerResp.type.format) { case 'dateTimeRFC1123': case 'dateTimeRFC3339': text += `\t\t${name}, err := time.Parse(${headerResp.type.format === 'dateTimeRFC1123' ? helpers.datetimeRFC1123Format : helpers.datetimeRFC3339Format}, val)\n`; break; case 'dateType': text += `\t\t${name}, err := time.Parse("${helpers.dateFormat}", val)\n`; break; case 'timeRFC3339': text += `\t\t${name}, err := time.Parse("${helpers.timeRFC3339Format}", val)\n`; break; case 'timeUnix': imports.add('strconv'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to'); text += '\t\tsec, err := strconv.ParseInt(val, 10, 64)\n'; name = 'to.Ptr(time.Unix(sec, 0))'; byRef = ''; break; } } // NOTE: only cases that required parsing will fall through to here text += '\t\tif err != nil {\n'; text += `\t\t\treturn ${zeroResp}, err\n`; text += '\t\t}\n'; text += `\t\t${respObj}.${headerResp.fieldName} = ${byRef}${name}\n`; text += '\t}\n'; return text; } function getZeroReturnValue(method, apiType) { let returnType = `${method.responseEnvelope.name}{}`; if (go.isLROMethod(method)) { if (apiType === 'api' || apiType === 'op') { // the api returns a *Poller[T] // the operation returns an *http.Response returnType = 'nil'; } } return returnType; } // Helper function to generate nil checks for a dotted path function generateNilChecks(path, prefix = 'page') { const segments = path.split('.'); const checks = []; for (let i = 0; i < segments.length; i++) { const currentPath = [prefix, ...segments.slice(0, 
i + 1)].join('.'); checks.push(`${currentPath} != nil`); } return checks.join(' && '); } function emitPagerDefinition(client, method, imports, injectSpans, generateFakes) { imports.add('context'); let text = `runtime.NewPager(runtime.PagingHandler[${method.responseEnvelope.name}]{\n`; text += `\t\tMore: func(page ${method.responseEnvelope.name}) bool {\n`; // there is no advancer for single-page pagers if (method.nextLinkName) { const nilChecks = generateNilChecks(method.nextLinkName); text += `\t\t\treturn ${nilChecks} && len(*page.${method.nextLinkName}) > 0\n`; text += '\t\t},\n'; } else { text += '\t\t\treturn false\n'; text += '\t\t},\n'; } text += `\t\tFetcher: func(ctx context.Context, page *${method.responseEnvelope.name}) (${method.responseEnvelope.name}, error) {\n`; const reqParams = helpers.getCreateRequestParameters(method); if (generateFakes) { text += `\t\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, "${client.name}.${fixUpMethodName(method)}")\n`; } if (method.nextLinkName) { let nextLinkVar; if (method.kind === 'pageableMethod') { text += '\t\t\tnextLink := ""\n'; nextLinkVar = 'nextLink'; text += '\t\t\tif page != nil {\n'; text += `\t\t\t\tnextLink = *page.${method.nextLinkName}\n\t\t\t}\n`; } else { nextLinkVar = `*page.${method.nextLinkName}`; } text += `\t\t\tresp, err := runtime.FetcherForNextLink(ctx, client.internal.Pipeline(), ${nextLinkVar}, func(ctx context.Context) (*policy.Request, error) {\n`; text += `\t\t\t\treturn client.${method.naming.requestMethod}(${reqParams})\n\t\t\t}, `; // nextPageMethod might be absent in some cases, see https://github.com/Azure/autorest/issues/4393 if (method.nextPageMethod) { const nextOpParams = helpers.getCreateRequestParametersSig(method.nextPageMethod).split(','); // keep the parameter names from the name/type tuples and find nextLink param for (let i = 0; i < nextOpParams.length; ++i) { const paramName = nextOpParams[i].trim().split(' ')[0]; const paramType = nextOpParams[i].trim().split(' 
')[1]; if (paramName.startsWith('next') && paramType === 'string') { nextOpParams[i] = 'encodedNextLink'; } else { nextOpParams[i] = paramName; } } // add a definition for the nextReq func that uses the nextLinkOperation text += '&runtime.FetcherForNextLinkOptions{\n\t\t\t\tNextReq: func(ctx context.Context, encodedNextLink string) (*policy.Request, error) {\n'; text += `\t\t\t\t\treturn client.${method.nextPageMethod.name}(${nextOpParams.join(', ')})\n\t\t\t\t},\n\t\t\t})\n`; } else { text += 'nil)\n'; } text += `\t\t\tif err != nil {\n\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n\t\t\t}\n`; text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`; text += '\t\t\t},\n'; } else { // this is the singular page case, no fetcher helper required text += `\t\t\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`; text += '\t\t\tif err != nil {\n'; text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n`; text += '\t\t\t}\n'; text += '\t\t\tresp, err := client.internal.Pipeline().Do(req)\n'; text += '\t\t\tif err != nil {\n'; text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, err\n`; text += '\t\t\t}\n'; text += '\t\t\tif !runtime.HasStatusCode(resp, http.StatusOK) {\n'; text += `\t\t\t\treturn ${method.responseEnvelope.name}{}, runtime.NewResponseError(resp)\n`; text += '\t\t\t}\n'; text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`; text += '\t\t},\n'; } if (injectSpans) { text += '\t\tTracer: client.internal.Tracer(),\n'; } text += '\t})\n'; return text; } function genApiVersionDoc(apiVersions) { if (apiVersions.length === 0) { return ''; } return `//\n// Generated from API version ${apiVersions.join(', ')}\n`; } function genRespErrorDoc(method) { if (!(method.responseEnvelope.result?.kind === 'headAsBooleanResult') && !go.isPageableMethod(method)) { // when head-as-boolean is enabled, no error is returned for 4xx status codes. 
// pager constructors don't return an error return '// If the operation fails it returns an *azcore.ResponseError type.\n'; } return ''; } function generateOperation(client, method, imports, injectSpans, generateFakes) { const params = getAPIParametersSig(method, imports); const returns = generateReturnsInfo(method, 'op'); let methodName = method.name; if (method.kind === 'pageableMethod') { methodName = fixUpMethodName(method); } let text = ''; const respErrDoc = genRespErrorDoc(method); const apiVerDoc = genApiVersionDoc(method.apiVersions); if (method.docs.summary || method.docs.description) { text += helpers.formatDocCommentWithPrefix(methodName, method.docs); } else if (respErrDoc.length > 0 || apiVerDoc.length > 0) { // if the method has no doc comment but we're adding other // doc comments, add an empty method name comment. this preserves // existing behavior and makes the docs look better overall. text += `// ${methodName} -\n`; } text += respErrDoc; text += apiVerDoc; if (go.isLROMethod(method)) { methodName = method.naming.internalMethod; } else { for (const param of values(helpers.getMethodParameters(method))) { text += helpers.formatCommentAsBulletItem(param.name, param.docs); } } text += `func (client *${client.name}) ${methodName}(${params}) (${returns.join(', ')}) {\n`; const reqParams = helpers.getCreateRequestParameters(method); if (method.kind === 'pageableMethod') { text += '\treturn '; text += emitPagerDefinition(client, method, imports, injectSpans, generateFakes); text += '}\n\n'; return text; } text += '\tvar err error\n'; let operationName = `"${client.name}.${fixUpMethodName(method)}"`; if (generateFakes && injectSpans) { text += `\tconst operationName = ${operationName}\n`; operationName = 'operationName'; } if (generateFakes) { text += `\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, ${operationName})\n`; } if (injectSpans) { text += `\tctx, endSpan := runtime.StartSpan(ctx, ${operationName}, client.internal.Tracer(), nil)\n`; 
text += '\tdefer func() { endSpan(err) }()\n'; } const zeroResp = getZeroReturnValue(method, 'op'); text += `\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`; text += '\tif err != nil {\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; text += '\thttpResp, err := client.internal.Pipeline().Do(req)\n'; text += '\tif err != nil {\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; text += `\tif !runtime.HasStatusCode(httpResp, ${helpers.formatStatusCodes(method.httpStatusCodes)}) {\n`; text += '\t\terr = runtime.NewResponseError(httpResp)\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; // HAB with headers response is handled in protocol responder if (method.responseEnvelope.result?.kind === 'headAsBooleanResult' && method.responseEnvelope.headers.length === 0) { text += `\treturn ${method.responseEnvelope.name}{${method.responseEnvelope.result.fieldName}: httpResp.StatusCode >= 200 && httpResp.StatusCode < 300}, nil\n`; } else { if (go.isLROMethod(method)) { text += '\treturn httpResp, nil\n'; } else if (needsResponseHandler(method)) { // also cheating here as at present the only param to the responder is an http.Response text += `\tresp, err := client.${method.naming.responseMethod}(httpResp)\n`; text += '\treturn resp, err\n'; } else if (method.responseEnvelope.result?.kind === 'binaryResult') { text += `\treturn ${method.responseEnvelope.name}{${method.responseEnvelope.result.fieldName}: httpResp.Body}, nil\n`; } else { text += `\treturn ${method.responseEnvelope.name}{}, nil\n`; } } text += '}\n\n'; return text; } function createProtocolRequest(azureARM, client, method, imports) { let name = method.name; if (method.kind !== 'nextPageMethod') { name = method.naming.requestMethod; } for (const param of values(method.parameters)) { if (param.location !== 'method' || !go.isRequiredParameter(param)) { continue; } imports.addImportForType(param.type); } const returns = ['*policy.Request', 'error']; let text = 
`${comment(name, '// ')} creates the ${method.name} request.\n`; text += `func (client *${client.name}) ${name}(${helpers.getCreateRequestParametersSig(method)}) (${returns.join(', ')}) {\n`; const hostParams = new Array(); for (const parameter of client.parameters) { if (parameter.kind === 'uriParam') { hostParams.push(parameter); } } let hostParam; if (azureARM) { hostParam = 'client.internal.Endpoint()'; } else if (client.templatedHost) { imports.add('strings'); // we have a templated host text += `\thost := "${client.templatedHost}"\n`; // get all the host params on the client for (const hostParam of hostParams) { text += `\thost = strings.ReplaceAll(host, "{${hostParam.uriPathSegment}}", ${helpers.formatValue(`client.${hostParam.name}`, hostParam.type, imports)})\n`; } // check for any method local host params for (const param of values(method.parameters)) { if (param.location === 'method' && param.kind === 'uriParam') { text += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(helpers.getParamName(param), param.type, imports)})\n`; } } hostParam = 'host'; } else if (hostParams.length === 1) { // simple parameterized host case hostParam = 'client.' + hostParams[0].name; } else { throw new CodegenError('InternalError', `no host or endpoint defined for method ${client.name}.${method.name}`); } const methodParamGroups = helpers.getMethodParamGroups(method); const hasPathParams = methodParamGroups.pathParams.length > 0; // storage needs the client.u to be the source-of-truth for the full path. // however, swagger requires that all operations specify a path, which is at odds with storage. // to work around this, storage specifies x-ms-path paths with path params but doesn't // actually reference the path params (i.e. no params with which to replace the tokens). // so, if a path contains tokens but there are no path params, skip emitting the path. 
const pathStr = method.httpPath; const pathContainsParms = pathStr.includes('{'); if (hasPathParams || (!pathContainsParms && pathStr.length > 1)) { // there are path params, or the path doesn't contain tokens and is not "/" so emit it text += `\turlPath := "${method.httpPath}"\n`; hostParam = `runtime.JoinPaths(${hostParam}, urlPath)`; } if (hasPathParams) { // swagger defines path params, emit path and replace tokens imports.add('strings'); // replace path parameters for (const pp of methodParamGroups.pathParams) { // emit check to ensure path param isn't an empty string. we only need // to do this for params that have an underlying type of string. const choiceIsString = function (type) { return type.kind === 'constant' && type.type === 'string'; }; // TODO: https://github.com/Azure/autorest.go/issues/1593 if (pp.kind === 'pathScalarParam' && ((pp.type.kind === 'string' || choiceIsString(pp.type)) && pp.isEncoded)) { const paramName = helpers.getParamName(pp); imports.add('errors'); text += `\tif ${paramName} == "" {\n`; text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`; text += '\t}\n'; } let paramValue = helpers.formatParamValue(pp, imports); if (pp.isEncoded) { imports.add('net/url'); paramValue = `url.PathEscape(${helpers.formatParamValue(pp, imports)})`; } text += `\turlPath = strings.ReplaceAll(urlPath, "{${pp.pathSegment}}", ${paramValue})\n`; } } text += `\treq, err := runtime.NewRequest(ctx, http.Method${capitalize(method.httpMethod)}, ${hostParam})\n`; text += '\tif err != nil {\n'; text += '\t\treturn nil, err\n'; text += '\t}\n'; // helper to build nil checks for param groups const emitParamGroupCheck = function (param) { if (!param.group) { throw new CodegenError('InternalError', `emitParamGroupCheck called for ungrouped parameter ${param.name}`); } let client = ''; if (param.location === 'client') { client = 'client.'; } const paramGroupName = uncapitalize(param.group.name); let optionalParamGroupCheck = 
`${client}${paramGroupName} != nil && `; if (param.group.required) { optionalParamGroupCheck = ''; } return `\tif ${optionalParamGroupCheck}${client}${paramGroupName}.${capitalize(param.name)} != nil {\n`; }; // add query parameters const encodedParams = methodParamGroups.encodedQueryParams; const unencodedParams = methodParamGroups.unencodedQueryParams; const emitQueryParam = function (qp, setter) { let qpText = ''; if (qp.location === 'method' && go.isClientSideDefault(qp.style)) { qpText = emitClientSideDefault(qp, qp.style, (name, val) => { return `\treqQP.Set(${name}, ${val})`; }, imports); } else if (go.isRequiredParameter(qp) || go.isLiteralParameter(qp) || (qp.location === 'client' && go.isClientSideDefault(qp.style))) { qpText = `\t${setter}\n`; } else if (qp.location === 'client' && !qp.group) { // global optional param qpText = `\tif client.${qp.name} != nil {\n`; qpText += `\t\t${setter}\n`; qpText += '\t}\n'; } else { qpText = emitParamGroupCheck(qp); qpText += `\t\t${setter}\n`; qpText += '\t}\n'; } return qpText; }; // emit encoded params first if (encodedParams.length > 0) { text += '\treqQP := req.Raw().URL.Query()\n'; for (const qp of values(encodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) { let setter; if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') { setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`; // emit a type conversion for the qv based on the array's element type let queryVal; const arrayQP = qp.type; switch (arrayQP.elementType.kind) { case 'constant': switch (arrayQP.elementType.type) { case 'string': queryVal = 'string(qv)'; break; default: imports.add('fmt'); queryVal = 'fmt.Sprintf("%d", qv)'; } break; case 'string': queryVal = 'qv'; break; default: imports.add('fmt'); queryVal = 'fmt.Sprintf("%v", qv)'; } setter += `\t\treqQP.Add("${qp.queryParameter}", ${queryVal})\n`; setter += '\t}'; } else { // cannot initialize setter to this value as 
helpers.formatParamValue() can change imports setter = `reqQP.Set("${qp.queryParameter}", ${helpers.formatParamValue(qp, imports)})`; } text += emitQueryParam(qp, setter); } text += '\treq.Raw().URL.RawQuery = reqQP.Encode()\n'; } // tack on any unencoded params to the end if (unencodedParams.length > 0) { if (encodedParams.length > 0) { text += '\tunencodedParams := []string{req.Raw().URL.RawQuery}\n'; } else { text += '\tunencodedParams := []string{}\n'; } for (const qp of values(unencodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) { let setter; if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') { setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`; setter += `\t\tunencodedParams = append(unencodedParams, "${qp.queryParameter}="+qv)\n`; setter += '\t}'; } else { setter = `unencodedParams = append(unencodedParams, "${qp.queryParameter}="+${helpers.formatParamValue(qp, imports)})`; } text += emitQueryParam(qp, setter); } imports.add('strings'); text += '\treq.Raw().URL.RawQuery = strings.Join(unencodedParams, "&")\n'; } if (method.kind !== 'nextPageMethod' && method.responseEnvelope.result?.kind === 'binaryResult') { // skip auto-body downloading for binary stream responses text += '\truntime.SkipBodyDownload(req)\n'; } // add specific request headers const emitHeaderSet = function (headerParam, prefix) { if (headerParam.kind === 'headerMapParam') { let headerText = `${prefix}for k, v := range ${helpers.getParamName(headerParam)} {\n`; headerText += `${prefix}\tif v != nil {\n`; headerText += `${prefix}\t\treq.Raw().Header["${headerParam.headerName}"+k] = []string{*v}\n`; headerText += `${prefix}}\n`; headerText += `${prefix}}\n`; return headerText; } else if (headerParam.location === 'method' && go.isClientSideDefault(headerParam.style)) { return emitClientSideDefault(headerParam, headerParam.style, (name, val) => { return `${prefix}req.Raw().Header[${name}] = []string{${val}}`; 
}, imports); } else { return `${prefix}req.Raw().Header["${headerParam.headerName}"] = []string{${helpers.formatParamValue(headerParam, imports)}}\n`; } }; let contentType; for (const param of methodParamGroups.headerParams.sort((a, b) => { return helpers.sortAscending(a.headerName, b.headerName); })) { if (param.headerName.match(/^content-type$/)) { // canonicalize content-type as req.SetBody checks for it via its canonicalized name :( param.headerName = 'Content-Type'; } if (param.headerName === 'Content-Type' && param.style === 'literal') { // the content-type header will be set as part of emitSetBodyWithErrCheck // to handle cases where the body param is optional. we don't want to set // the content-type if the body is nil. // we do it like this as tsp specifies content-type while swagger does not. contentType = helpers.formatParamValue(param, imports); } else if (go.isRequiredParameter(param) || go.isLiteralParameter(param) || go.isClientSideDefault(param.style)) { text += emitHeaderSet(param, '\t'); } else if (param.location === 'client' && !param.group) { // global optional param text += `\tif client.${param.name} != nil {\n`; text += emitHeaderSet(param, '\t'); text += '\t}\n'; } else { text += emitParamGroupCheck(param); text += emitHeaderSet(param, '\t\t'); text += '\t}\n'; } } // note that these are mutually exclusive const bodyParam = methodParamGroups.bodyParam; const formBodyParams = methodParamGroups.formBodyParams; const multipartBodyParams = methodParamGroups.multipartBodyParams; const partialBodyParams = methodParamGroups.partialBodyParams; const emitSetBodyWithErrCheck = function (setBodyParam, contentType) { let content = `if err := ${setBodyParam}; err != nil {\n\treturn nil, err\n}\n;`; if (contentType) { content = `req.Raw().Header["Content-Type"] = []string{${contentType}}\n` + content; } return content; }; if (bodyParam) { if (bodyParam.bodyFormat === 'JSON' || bodyParam.bodyFormat === 'XML') { // default to the body param name let body = 
helpers.getParamName(bodyParam); if (bodyParam.type.kind === 'literal') { // if the value is constant, embed it directly body = helpers.formatLiteralValue(bodyParam.type, true); } else if (bodyParam.bodyFormat === 'XML' && bodyParam.type.kind === 'slice') { // for XML payloads, create a wrapper type if the payload is an array imports.add('encoding/xml'); text += '\ttype wrapper struct {\n'; let tagName = go.getTypeDeclaration(bodyParam.type); if (bodyParam.xml?.name) { tagName = bodyParam.xml.name; } text += `\t\tXMLName xml.Name \`xml:"${tagName}"\`\n`; const fieldName = capitalize(bodyParam.name); let tag = go.getTypeDeclaration(bodyParam.type.elementType); if (bodyParam.type.elementType.kind === 'model' && bodyParam.type.elementType.xml?.name) { tag = bodyParam.type.elementType.xml.name; } text += `\t\t${fieldName} *${go.getTypeDeclaration(bodyParam.type)} \`xml:"${tag}"\`\n`; text += '\t}\n'; let addr = '&'; if (!go.isRequiredParameter(bodyParam) && !bodyParam.byValue) { addr = ''; } body = `wrapper{${fieldName}: ${addr}${body}}`; } else if (bodyParam.type.kind === 'time' && bodyParam.type.format !== 'dateTimeRFC3339') { // wrap the body in the internal time type // no need for dateTimeRFC3339 as the JSON marshaler defaults to that. 
body = `${bodyParam.type.format}(${body})`; } else if (isArrayOfDateTimeForMarshalling(bodyParam.type)) { const timeInfo = isArrayOfDateTimeForMarshalling(bodyParam.type); let elementPtr = '*'; if (timeInfo?.elemByVal) { elementPtr = ''; } text += `\taux := make([]${elementPtr}${timeInfo?.format}, len(${body}))\n`; text += `\tfor i := 0; i < len(${body}); i++ {\n`; text += `\t\taux[i] = (${elementPtr}${timeInfo?.format})(${body}[i])\n`; text += '\t}\n'; body = 'aux'; } else if (isMapOfDateTime(bodyParam.type)) { const timeType = isMapOfDateTime(bodyParam.type); text += `\taux := map[string]*${timeType}{}\n`; text += `\tfor k, v := range ${body} {\n`; text += `\t\taux[k] = (*${timeType})(v)\n`; text += '\t}\n'; body = 'aux'; } let setBody = `runtime.MarshalAs${getMediaFormat(bodyParam.type, bodyParam.bodyFormat, `req, ${body}`)}`; if (bodyParam.type.kind === 'rawJSON') { imports.add('bytes'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming'); setBody = `req.SetBody(streaming.NopCloser(bytes.NewReader(${body})), "application/${bodyParam.bodyFormat.toLowerCase()}")`; } if (go.isRequiredParameter(bodyParam) || go.isLiteralParameter(bodyParam)) { text += `\t${emitSetBodyWithErrCheck(setBody, contentType)}`; text += '\treturn req, nil\n'; } else { text += emitParamGroupCheck(bodyParam); text += `\t${emitSetBodyWithErrCheck(setBody, contentType)}`; text += '\t\treturn req, nil\n'; text += '\t}\n'; text += '\treturn req, nil\n'; } } else if (bodyParam.bodyFormat === 'binary') { if (go.isRequiredParameter(bodyParam)) { text += `\t${emitSetBodyWithErrCheck(`req.SetBody(${bodyParam.name}, ${bodyParam.contentType})`, contentType)}`; text += '\treturn req, nil\n'; } else { text += emitParamGroupCheck(bodyParam); text += `\t${emitSetBodyWithErrCheck(`req.SetBody(${helpers.getParamName(bodyParam)}, ${bodyParam.contentType})`, contentType)}`; text += '\treturn req, nil\n'; text += '\t}\n'; text += '\treturn req, nil\n'; } } else if (bodyParam.bodyFormat === 
'Text') { imports.add('strings'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming'); if (go.isRequiredParameter(bodyParam)) { text += `\tbody := streaming.NopCloser(strings.NewReader(${bodyParam.name}))\n`; text += `\t${emitSetBodyWithErrCheck(`req.SetBody(body, ${bodyParam.contentType})`, contentType)}`; text += '\treturn req, nil\n'; } else { text += emitParamGroupCheck(bodyParam); text += `\tbody := streaming.NopCloser(strings.NewReader(${helpers.getParamName(bodyParam)}))\n`; text += `\t${emitSetBodyWithErrCheck(`req.SetBody(body, ${bodyParam.contentType})`, contentType)}`; text += '\treturn req, nil\n'; text += '\t}\n'; text += '\treturn req, nil\n'; } } } else if (partialBodyParams.length > 0) { // partial body params are discrete params that are all fields within an internal struct. // define and instantiate an instance of the wire type, using the values from each param. text += '\tbody := struct {\n'; for (const partialBodyParam of partialBodyParams) { text += `\t\t${capitalize(partialBodyParam.serializedName)} ${helpers.star(partialBodyParam)}${go.getTypeDeclaration(partialBodyParam.type)} \`${partialBodyParam.format.toLowerCase()}:"${partialBodyParam.serializedName}"\`\n`; } text += '\t}{\n'; // required params are emitted as initializers in the struct literal for (const partialBodyParam of partialBodyParams) { if (go.isRequiredParameter(partialBodyParam)) { text += `\t\t${capitalize(partialBodyParam.serializedName)}: ${uncapitalize(partialBodyParam.name)},\n`; } } text += '\t}\n'; // now populate any optional params from the options type for (const partialBodyParam of partialBodyParams) { if (!go.isRequiredParameter(partialBodyParam)) { text += emitParamGroupCheck(partialBodyParam); text += `\t\tbody.${capitalize(partialBodyParam.serializedName)} = options.${capitalize(partialBodyParam.name)}\n\t}\n`; } } // TODO: spread params are JSON only https://github.com/Azure/autorest.go/issues/1455 text += '\treq.Raw().Header["Content-Type"] = 
[]string{"application/json"}\n'; text += '\tif err := runtime.MarshalAsJSON(req, body); err != nil {\n\t\treturn nil, err\n\t}\n'; text += '\treturn req, nil\n'; } else if (multipartBodyParams.length > 0) { if (multipartBodyParams.length === 1 && multipartBodyParams[0].type.kind === 'model' && multipartBodyParams[0].type.annotations.multipartFormData) { text += `\tformData, err := ${multipartBodyParams[0].name}.toMultipartFormData()\n`; text += '\tif err != nil {\n\t\treturn nil, err\n\t}\n'; } else { text += '\tformData := map[string]any{}\n'; for (const param of multipartBodyParams) { const setter = `formData["${param.name}"] = ${helpers.getParamName(param)}`; if (go.isRequiredParameter(param)) { text += `\t${setter}\n`; } else { text += emitParamGroupCheck(param); text += `\t${setter}\n\t}\n`; } } } text += '\tif err := runtime.SetMultipartFormData(req, formData); err != nil {\n\t\treturn nil, err\n\t}\n'; text += '\treturn req, nil\n'; } else if (formBodyParams.length > 0) { const emitFormData = function (param, setter) { let formDataText = ''; if (go.isRequiredParameter(param)) { formDataText = `\t${setter}\n`; } else { formDataText = emitParamGroupCheck(param); formDataText += `\t\t${setter}\n`; formDataText += '\t}\n'; } return formDataText; }; imports.add('net/url'); imports.add('strings'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming'); text += '\tformData := url.Values{}\n'; // find all the form body params for (const param of formBodyParams) { const setter = `formData.Set("${param.formDataName}", ${helpers.formatParamValue(param, imports)})`; text += emitFormData(param, setter); } text += '\tbody := streaming.NopCloser(strings.NewReader(formData.Encode()))\n'; text += `\t${emitSetBodyWithErrCheck('req.SetBody(body, "application/x-www-form-urlencoded")')}`; text += '\treturn req, nil\n'; } else { text += '\treturn req, nil\n'; } text += '}\n\n'; return text; } function emitClientSideDefault(param, csd, setterFormat, imports) { const 
defaultVar = uncapitalize(param.name) + 'Default'; let text = `\t${defaultVar} := ${helpers.formatLiteralValue(csd.defaultValue, true)}\n`; text += `\tif options != nil && options.${capitalize(param.name)} != nil {\n`; text += `\t\t${defaultVar} = *options.${capitalize(param.name)}\n`; text += '}\n'; let serializedName; switch (param.kind) { case 'headerCollectionParam': case 'headerScalarParam': serializedName = param.headerName; break; case 'queryCollectionParam': case 'queryScalarParam': serializedName = param.queryParameter; break; } text += setterFormat(`"${serializedName}"`, helpers.formatValue(defaultVar, param.type, imports)) + '\n'; return text; } function getMediaFormat(type, mediaType, param) { let marshaller = mediaType; let format = ''; if (type.kind === 'encodedBytes') { marshaller = 'ByteArray'; format = `, runtime.Base64${type.encoding}Format`; } return `${marshaller}(${param}${format})`; } function isArrayOfDateTimeForMarshalling(paramType) { if (paramType.kind !== 'slice') { return undefined; } if (paramType.elementType.kind !== 'time') { return undefined; } switch (paramType.elementType.format) { case 'dateType': case 'dateTimeRFC1123': case 'timeRFC3339': case 'timeUnix': return { format: paramType.elementType.format, elemByVal: paramType.elementTypeByValue }; default: // dateTimeRFC3339 uses the default marshaller return undefined; } } // returns true if the method requires a response handler. // this is used to unmarshal the response body, parse response headers, or both. 
function needsResponseHandler(method) {
  // a handler is needed when there is a body schema to unmarshal or response headers to parse
  return helpers.hasSchemaResponse(method) || method.responseEnvelope.headers.length > 0;
}
// generateResponseUnmarshaller emits Go code that decodes the HTTP response body with
// runtime.UnmarshalAs<format>(resp, ...). If decoding fails, the emitted code returns the
// handler's zero value together with the error.
// A time value is first decoded into the internal time wrapper type named by its format and
// then converted to *time.Time in result.<field>. Slices of time values are likewise decoded
// into a slice of the wrapper type before being copied into a time.Time slice.
function generateResponseUnmarshaller(method, type, format, unmarshalTarget) {
  let unmarshallerText = '';
  // zero value returned by the generated handler alongside any unmarshalling error
  const zeroValue = getZeroReturnValue(method, 'handler');
  if (type.kind === 'time') {
    // use the designated time type for unmarshalling
    unmarshallerText += `\tvar aux *${type.format}\n`;
    unmarshallerText += `\tif err := runtime.UnmarshalAs${format}(resp, &aux); err != nil {\n`;
    unmarshallerText += `\t\treturn ${zeroValue}, err\n`;
    unmarshallerText += '\t}\n';
    unmarshallerText += `\tresult.${helpers.getResultFieldName(method)} = (*time.Time)(aux)\n`;
    return unmarshallerText;
  } else if (isArrayOfDateTime(type)) {
    // unmarshalling arrays of date/time is a little more involved
    const timeInfo = isArrayOfDateTime(type);
    // elements are pointers unless the slice holds them by value
    let elementPtr = '*';
    if (timeInfo?.elemByVal) {
      elementPtr = '';
    }
    unmarshallerText += `\tvar aux []${elementPtr}${timeInfo?.format}\n`;
    unmarshallerText += `\tif err := runtime.UnmarshalAs${format}(resp, &aux); err != nil {\n`;
    unmarshallerText += `\t\treturn ${zeroValue}, err\n`;
    unmarshallerText += '\t}\n';
    unmarshallerText += `\tcp := make([]${elementPtr}time.Time, len(aux))\n`;
    unmarshallerText += '\tfor i := 0; i < len(aux); i++ {\n