UNPKG

@autorest/go

Version:
1,011 lines (1,010 loc) 75 kB
/*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ import * as go from '../../codemodel.go/src/index.js'; import { ensureNameCase } from '../../naming.go/src/naming.js'; import { capitalize, comment, uncapitalize } from '@azure-tools/codegen'; import { values } from '@azure-tools/linq'; import * as helpers from './helpers.js'; import { ImportManager } from './imports.js'; import { CodegenError } from './errors.js'; // represents the generated content for an operation group export class OperationGroupContent { name; content; constructor(name, content) { this.name = name; this.content = content; } } // Creates the content for all <operation>.go files export async function generateOperations(codeModel) { // generate protocol operations const operations = new Array(); if (codeModel.clients.length === 0) { return operations; } const azureARM = codeModel.type === 'azure-arm'; for (const client of codeModel.clients) { // the list of packages to import const imports = new ImportManager(); if (client.methods.length > 0) { // add standard imports for clients with methods. // clients that are purely hierarchical (i.e. having no APIs) won't need them. imports.add('net/http'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime'); } imports.add(azureARM ? 'github.com/Azure/azure-sdk-for-go/sdk/azcore/arm' : 'github.com/Azure/azure-sdk-for-go/sdk/azcore'); // generate client type let clientText = helpers.formatDocComment(client.docs); clientText += '// Don\'t use this type directly, use '; if (client.instance?.kind === 'constructable' && client.instance.constructors.length === 1) { clientText += `${client.instance.constructors[0].name}() instead.\n`; } else if (client.parent) { // find the accessor method let accessorMethod; for (const clientAccessor of client.parent.clientAccessors) { if (clientAccessor.subClient === client) { accessorMethod = clientAccessor.name; break; } } if (!accessorMethod) { throw new CodegenError('InternalError', `didn't find accessor method for client ${client.name} on parent client ${client.parent.name}`); } clientText += `[${client.parent.name}.${accessorMethod}] instead.\n`; } else { clientText += 'a constructor function instead.\n'; } clientText += `type ${client.name} struct {\n`; clientText += `\tinternal *${azureARM ? 'arm' : 'azcore'}.Client\n`; // check for any optional host params const optionalParams = new Array(); const isParamPointer = function (param) { // for client params, only optional and flag types are passed by pointer return param.style === 'flag' || param.style === 'optional'; }; // now emit any client params (non parameterized host params case) if (client.parameters.length > 0) { const addedGroups = new Set(); for (const clientParam of values(client.parameters)) { if (go.isLiteralParameter(clientParam.style)) { continue; } if (clientParam.group) { if (!addedGroups.has(clientParam.group.groupName)) { clientText += `\t${uncapitalize(clientParam.group.groupName)} ${!isParamPointer(clientParam) ? '' : '*'}${clientParam.group.groupName}\n`; addedGroups.add(clientParam.group.groupName); } continue; } clientText += `\t${clientParam.name} `; if (!isParamPointer(clientParam)) { clientText += `${go.getTypeDeclaration(clientParam.type)}\n`; } else { clientText += `${helpers.formatParameterTypeName(clientParam)}\n`; } if (!go.isRequiredParameter(clientParam.style)) { optionalParams.push(clientParam); } } } // end of client definition clientText += '}\n\n'; clientText += generateConstructors(client, codeModel.type, imports); // generate client accessors and operations let opText = ''; for (const clientAccessor of client.clientAccessors) { opText += `// ${clientAccessor.name} creates a new instance of [${clientAccessor.subClient.name}].\n`; opText += `func (client *${client.name}) ${clientAccessor.name}() *${clientAccessor.subClient.name} {\n`; opText += `\treturn &${clientAccessor.subClient.name}{\n`; opText += '\t\tinternal: client.internal,\n'; // propagate all client params for (const param of client.parameters) { if (go.isLiteralParameter(param.style)) { continue; } opText += `\t\t${param.name}: client.${param.name},\n`; } opText += '\t}\n}\n\n'; } const nextPageMethods = new Array(); for (const method of client.methods) { // protocol creation can add imports to the list so // it must be done before the imports are written out if (go.isLROMethod(method)) { // generate Begin method opText += generateLROBeginMethod(method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes); } opText += generateOperation(method, imports, codeModel.options.injectSpans, codeModel.options.generateFakes); opText += createProtocolRequest(azureARM, method, imports); if (method.kind !== 'lroMethod') { // LRO responses are handled elsewhere, with the exception of pageable LROs opText += createProtocolResponse(method, imports); } if ((method.kind === 'lroPageableMethod' || method.kind === 'pageableMethod') && method.nextPageMethod && !nextPageMethods.includes(method.nextPageMethod)) { // track the next page methods to generate as multiple operations can use the same next page operation nextPageMethods.push(method.nextPageMethod); } } for (const method of nextPageMethods) { opText += createProtocolRequest(azureARM, method, imports); } // stitch it all together let text = helpers.contentPreamble(codeModel); text += imports.text(); text += clientText; text += opText; operations.push(new OperationGroupContent(client.name, text)); } return operations; } /** * generates all modeled client constructors and client options types. * if there are no client constructors, the empty string is returned. * * @param client the client for which to generate constructors and the client options type * @param imports the import manager currently in scope * @returns the client constructor code or the empty string */ function generateConstructors(client, type, imports) { if (client.instance?.kind !== 'constructable') { return ''; } const clientOptions = client.instance.options; let ctorText = ''; if (clientOptions.kind === 'clientOptions') { // for non-ARM, the options type will always be a parameter group ctorText += `// ${clientOptions.name} contains the optional values for creating a [${client.name}].\n`; ctorText += `type ${clientOptions.name} struct {\n\tazcore.ClientOptions\n`; for (const param of clientOptions.parameters) { if (go.isAPIVersionParameter(param)) { // we use azcore.ClientOptions.APIVersion continue; } ctorText += helpers.formatDocCommentWithPrefix(ensureNameCase(param.name), param.docs); if (go.isClientSideDefault(param.style)) { if (!param.docs.description && !param.docs.summary) { ctorText += '\n'; } ctorText += `\t${comment(`The default value is ${helpers.formatLiteralValue(param.style.defaultValue, false)}`, '// ')}.\n`; } ctorText += `\t${ensureNameCase(param.name)} *${go.getTypeDeclaration(param.type)}\n`; } ctorText += '}\n\n'; } for (const constructor of client.instance.constructors) { const ctorParams = new Array(); const paramDocs = new Array(); // ctor params can also be present in the supplemental endpoint parameters const consolidatedCtorParams = new Array(); if (client.instance.endpoint) { consolidatedCtorParams.push(client.instance.endpoint.parameter); if (client.instance.endpoint.supplemental) { consolidatedCtorParams.push(...client.instance.endpoint.supplemental.parameters); } } for (const param of helpers.sortClientParameters(constructor.parameters, type)) { if (!consolidatedCtorParams.includes(param)) { consolidatedCtorParams.push(param); } } for (const ctorParam of consolidatedCtorParams) { if (!go.isRequiredParameter(ctorParam.style)) { // param is part of the options group continue; } imports.addImportForType(ctorParam.type); ctorParams.push(`${ctorParam.name} ${helpers.formatParameterTypeName(ctorParam)}`); if (ctorParam.docs.summary || ctorParam.docs.description) { paramDocs.push(helpers.formatCommentAsBulletItem(ctorParam.name, ctorParam.docs)); } } const emitProlog = function (optionsTypeName, tokenAuth, plOpts) { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime'); let bodyText = `\tif options == nil {\n\t\toptions = &${optionsTypeName}{}\n\t}\n`; let apiVersionConfig = ''; // check if there's an api version parameter let apiVersionParam; for (const param of consolidatedCtorParams) { switch (param.kind) { case 'headerScalarParam': case 'pathScalarParam': case 'queryScalarParam': case 'uriParam': if (param.isApiVersion) { apiVersionParam = param; } } } if (tokenAuth) { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud'); imports.add('fmt'); imports.add('reflect'); bodyText += '\tif reflect.ValueOf(options.Cloud).IsZero() {\n'; bodyText += '\t\toptions.Cloud = cloud.AzurePublic\n\t}\n'; bodyText += '\tc, ok := options.Cloud.Services[ServiceName]\n'; bodyText += '\tif !ok {\n'; bodyText += '\t\treturn nil, fmt.Errorf("provided Cloud field is missing configuration for %s", ServiceName)\n'; bodyText += '\t} else if c.Audience == "" {\n'; bodyText += '\t\treturn nil, fmt.Errorf("provided Cloud field is missing Audience for %s", ServiceName)\n\t}\n'; } if (apiVersionParam) { let location; let name; switch (apiVersionParam.kind) { case 'headerScalarParam': location = 'Header'; name = apiVersionParam.headerName; break; case 'pathScalarParam': case 'uriParam': location = 'Path'; // name isn't used for the path case break; case 'queryScalarParam': location = 'QueryParam'; name = apiVersionParam.queryParameter; break; } if (name) { name = `\n\t\t\tName: "${name}",`; } else { name = ''; } apiVersionConfig = `\n\t\tAPIVersion: runtime.APIVersionOptions{${name}\n\t\t\tLocation: runtime.APIVersionLocation${location},\n\t\t},`; if (!plOpts) { apiVersionConfig += '\n'; } } bodyText += `\tcl, err := azcore.NewClient(moduleName, moduleVersion, runtime.PipelineOptions{${apiVersionConfig}${plOpts ?? ''}}, &options.ClientOptions)\n`; return bodyText; }; // check if there's a credential parameter let credentialParam; for (const param of constructor.parameters) { if (param.kind === 'credentialParam') { credentialParam = param; break; } } let prolog; if (credentialParam) { switch (credentialParam.type.kind) { case 'tokenCredential': imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore'); paramDocs.push(helpers.formatCommentAsBulletItem('credential', { summary: 'used to authorize requests. Usually a credential from azidentity.' })); switch (clientOptions.kind) { case 'clientOptions': { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/policy'); const tokenPolicyOpts = '&policy.BearerTokenOptions{\n\t\t\tInsecureAllowCredentialWithHTTP: options.InsecureAllowCredentialWithHTTP,\n\t\t}'; // we assume a single scope. this is enforced when adapting the data from tcgc const tokenPolicy = `\n\t\tPerCall: []policy.Policy{\n\t\truntime.NewBearerTokenPolicy(credential, []string{c.Audience + "${helpers.splitScope(credentialParam.type.scopes[0]).scope}"}, ${tokenPolicyOpts}),\n\t\t},\n`; prolog = emitProlog(go.getTypeDeclaration(clientOptions), true, tokenPolicy); break; } case 'armClientOptions': // this is the ARM case imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm'); prolog = '\tcl, err := arm.NewClient(moduleName, moduleVersion, credential, options)\n'; break; } break; } } else { prolog = emitProlog(go.getTypeDeclaration(clientOptions), false); } // add client options last ctorParams.push(`options ${helpers.formatParameterTypeName(clientOptions)}`); paramDocs.push(helpers.formatCommentAsBulletItem('options', { summary: 'Contains optional client configuration. Pass nil to accept the default values.' })); ctorText += `// ${constructor.name} creates a new instance of ${client.name} with the specified values.\n`; for (const doc of paramDocs) { ctorText += doc; } ctorText += `func ${constructor.name}(${ctorParams.join(', ')}) (*${client.name}, error) {\n`; ctorText += prolog; ctorText += '\tif err != nil {\n'; ctorText += '\t\treturn nil, err\n'; ctorText += '\t}\n'; // handle any client-side defaults if (clientOptions.kind === 'clientOptions') { for (const param of clientOptions.parameters) { if (go.isClientSideDefault(param.style)) { let name; if (go.isAPIVersionParameter(param)) { name = 'APIVersion'; } else { name = ensureNameCase(param.name); } ctorText += `\t${param.name} := ${helpers.formatLiteralValue(param.style.defaultValue, false)}\n`; ctorText += `\tif options.${name} != ${helpers.zeroValue(param)} {\n\t\t${param.name} = ${helpers.star(param.byValue)}options.${name}\n\t}\n`; } } } // construct the supplemental path and join it to the endpoint if (client.instance.endpoint?.supplemental) { imports.add('strings'); ctorText += `\thost := "${client.instance.endpoint.supplemental.path}"\n`; for (const param of client.instance.endpoint.supplemental.parameters) { ctorText += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(param.name, param.type, imports)})\n`; } // the endpoint param is always the first ctor param const endpointParam = client.instance.constructors[0].parameters[0]; ctorText += `\t${endpointParam.name} = runtime.JoinPaths(${endpointParam.name}, host)\n`; } // construct client literal let clientVar = 'client'; // ensure clientVar doesn't collide with any params for (const param of consolidatedCtorParams) { if (param.name === clientVar) { clientVar = ensureNameCase(client.name, true); break; } } ctorText += `\t${clientVar} := &${client.name}{\n`; // NOTE: we don't enumerate consolidatedCtorParams here // as any supplemental endpoint params are ephemeral and // consumed during client construction. for (const parameter of values(client.parameters)) { if (go.isLiteralParameter(parameter.style)) { continue; } // each client field will have a matching parameter with the same name ctorText += `\t\t${parameter.name}: ${parameter.name},\n`; } ctorText += '\tinternal: cl,\n'; ctorText += '\t}\n'; ctorText += `\treturn ${clientVar}, nil\n`; ctorText += '}\n\n'; } return ctorText; } // use this to generate the code that will help process values returned in response headers function formatHeaderResponseValue(headerResp, imports, respObj, zeroResp) { // dictionaries are handled slightly different so we do that first if (headerResp.kind === 'headerMapResponse') { imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to'); imports.add('strings'); const headerPrefix = headerResp.headerName; let text = '\tfor hh := range resp.Header {\n'; text += `\t\tif len(hh) > len("${headerPrefix}") && strings.EqualFold(hh[:len("${headerPrefix}")], "${headerPrefix}") {\n`; text += `\t\t\tif ${respObj}.${headerResp.fieldName} == nil {\n`; text += `\t\t\t\t${respObj}.${headerResp.fieldName} = map[string]*string{}\n`; text += '\t\t\t}\n'; text += `\t\t\t${respObj}.${headerResp.fieldName}[hh[len("${headerPrefix}"):]] = to.Ptr(resp.Header.Get(hh))\n`; text += '\t\t}\n'; text += '\t}\n'; return text; } let text = `\tif val := resp.Header.Get("${headerResp.headerName}"); val != "" {\n`; let name = uncapitalize(headerResp.fieldName); let byRef = '&'; switch (headerResp.type.kind) { case 'constant': text += `\t\t${respObj}.${headerResp.fieldName} = (*${headerResp.type.name})(&val)\n`; text += '\t}\n'; return text; case 'encodedBytes': // a base-64 encoded value in string format imports.add('encoding/base64'); text += `\t\t${name}, err := base64.${helpers.formatBytesEncoding(headerResp.type.encoding)}Encoding.DecodeString(val)\n`; byRef = ''; break; case 'literal': text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`; text += '\t}\n'; return text; case 'scalar': imports.add('strconv'); switch (headerResp.type.type) { case 'bool': text += `\t\t${name}, err := strconv.ParseBool(val)\n`; break; case 'float32': text += `\t\t${name}32, err := strconv.ParseFloat(val, 32)\n`; text += `\t\t${name} := float32(${name}32)\n`; break; case 'float64': text += `\t\t${name}, err := strconv.ParseFloat(val, 64)\n`; break; case 'int32': text += `\t\t${name}32, err := strconv.ParseInt(val, 10, 32)\n`; text += `\t\t${name} := int32(${name}32)\n`; break; case 'int64': text += `\t\t${name}, err := strconv.ParseInt(val, 10, 64)\n`; break; default: throw new CodegenError('InternalError', `unhandled scalar type ${headerResp.type.type}`); } break; case 'string': text += `\t\t${respObj}.${headerResp.fieldName} = &val\n`; text += '\t}\n'; return text; case 'time': imports.add('time'); switch (headerResp.type.format) { case 'dateTimeRFC1123': case 'dateTimeRFC3339': text += `\t\t${name}, err := time.Parse(${headerResp.type.format === 'dateTimeRFC1123' ? helpers.datetimeRFC1123Format : helpers.datetimeRFC3339Format}, val)\n`; break; case 'dateType': text += `\t\t${name}, err := time.Parse("${helpers.dateFormat}", val)\n`; break; case 'timeRFC3339': text += `\t\t${name}, err := time.Parse("${helpers.timeRFC3339Format}", val)\n`; break; case 'timeUnix': imports.add('strconv'); imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to'); text += '\t\tsec, err := strconv.ParseInt(val, 10, 64)\n'; name = 'to.Ptr(time.Unix(sec, 0))'; byRef = ''; break; } } // NOTE: only cases that required parsing will fall through to here text += '\t\tif err != nil {\n'; text += `\t\t\treturn ${zeroResp}, err\n`; text += '\t\t}\n'; text += `\t\t${respObj}.${headerResp.fieldName} = ${byRef}${name}\n`; text += '\t}\n'; return text; } function getZeroReturnValue(method, apiType) { let returnType = `${method.returns.name}{}`; if (go.isLROMethod(method)) { if (apiType === 'api' || apiType === 'op') { // the api returns a *Poller[T] // the operation returns an *http.Response returnType = 'nil'; } } return returnType; } // Helper function to generate nil checks for a dotted path function generateNilChecks(path, prefix = 'page') { const segments = path.split('.'); const checks = []; for (let i = 0; i < segments.length; i++) { const currentPath = [prefix, ...segments.slice(0, i + 1)].join('.'); checks.push(`${currentPath} != nil`); } return checks.join(' && '); } function emitPagerDefinition(method, imports, injectSpans, generateFakes) { imports.add('context'); let text = `runtime.NewPager(runtime.PagingHandler[${method.returns.name}]{\n`; text += `\t\tMore: func(page ${method.returns.name}) bool {\n`; // there is no advancer for single-page pagers if (method.nextLinkName) { const nilChecks = generateNilChecks(method.nextLinkName); text += `\t\t\treturn ${nilChecks} && len(*page.${method.nextLinkName}) > 0\n`; text += '\t\t},\n'; } else { text += '\t\t\treturn false\n'; text += '\t\t},\n'; } text += `\t\tFetcher: func(ctx context.Context, page *${method.returns.name}) (${method.returns.name}, error) {\n`; const reqParams = helpers.getCreateRequestParameters(method); if (generateFakes) { text += `\t\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, "${method.receiver.type.name}.${fixUpMethodName(method)}")\n`; } if (method.nextLinkName) { let nextLinkVar; if (method.kind === 'pageableMethod') { text += '\t\t\tnextLink := ""\n'; nextLinkVar = 'nextLink'; text += '\t\t\tif page != nil {\n'; text += `\t\t\t\tnextLink = *page.${method.nextLinkName}\n\t\t\t}\n`; } else { nextLinkVar = `*page.${method.nextLinkName}`; } text += `\t\t\tresp, err := runtime.FetcherForNextLink(ctx, client.internal.Pipeline(), ${nextLinkVar}, func(ctx context.Context) (*policy.Request, error) {\n`; text += `\t\t\t\treturn client.${method.naming.requestMethod}(${reqParams})\n\t\t\t}, `; // nextPageMethod might be absent in some cases, see https://github.com/Azure/autorest/issues/4393 if (method.nextPageMethod) { const nextOpParams = helpers.getCreateRequestParametersSig(method.nextPageMethod).split(','); // keep the parameter names from the name/type tuples and find nextLink param for (let i = 0; i < nextOpParams.length; ++i) { const paramName = nextOpParams[i].trim().split(' ')[0]; const paramType = nextOpParams[i].trim().split(' ')[1]; if (paramName.startsWith('next') && paramType === 'string') { nextOpParams[i] = 'encodedNextLink'; } else { nextOpParams[i] = paramName; } } // add a definition for the nextReq func that uses the nextLinkOperation text += '&runtime.FetcherForNextLinkOptions{\n\t\t\t\tNextReq: func(ctx context.Context, encodedNextLink string) (*policy.Request, error) {\n'; text += `\t\t\t\t\treturn client.${method.nextPageMethod.name}(${nextOpParams.join(', ')})\n\t\t\t\t},\n\t\t\t})\n`; } else { text += 'nil)\n'; } text += `\t\t\tif err != nil {\n\t\t\t\treturn ${method.returns.name}{}, err\n\t\t\t}\n`; text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`; text += '\t\t\t},\n'; } else { // this is the singular page case, no fetcher helper required text += `\t\t\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`; text += '\t\t\tif err != nil {\n'; text += `\t\t\t\treturn ${method.returns.name}{}, err\n`; text += '\t\t\t}\n'; text += '\t\t\tresp, err := client.internal.Pipeline().Do(req)\n'; text += '\t\t\tif err != nil {\n'; text += `\t\t\t\treturn ${method.returns.name}{}, err\n`; text += '\t\t\t}\n'; text += '\t\t\tif !runtime.HasStatusCode(resp, http.StatusOK) {\n'; text += `\t\t\t\treturn ${method.returns.name}{}, runtime.NewResponseError(resp)\n`; text += '\t\t\t}\n'; text += `\t\t\treturn client.${method.naming.responseMethod}(resp)\n`; text += '\t\t},\n'; } if (injectSpans) { text += '\t\tTracer: client.internal.Tracer(),\n'; } text += '\t})\n'; return text; } function genApiVersionDoc(apiVersions) { if (apiVersions.length === 0) { return ''; } return `//\n// Generated from API version ${apiVersions.join(', ')}\n`; } function genRespErrorDoc(method) { if (!(method.returns.result?.kind === 'headAsBooleanResult') && !go.isPageableMethod(method)) { // when head-as-boolean is enabled, no error is returned for 4xx status codes. // pager constructors don't return an error return '// If the operation fails it returns an *azcore.ResponseError type.\n'; } return ''; } /** * returns the receiver definition for a client * * @param receiver the receiver for which to emit the definition * @returns the receiver definition */ function getClientReceiverDefinition(receiver) { return `(${receiver.name} ${receiver.byValue ? '' : '*'}${receiver.type.name})`; } function generateOperation(method, imports, injectSpans, generateFakes) { const params = getAPIParametersSig(method, imports); const returns = generateReturnsInfo(method, 'op'); let methodName = method.name; if (method.kind === 'pageableMethod') { methodName = fixUpMethodName(method); } let text = ''; const respErrDoc = genRespErrorDoc(method); const apiVerDoc = genApiVersionDoc(method.apiVersions); if (method.docs.summary || method.docs.description) { text += helpers.formatDocCommentWithPrefix(methodName, method.docs); } else if (respErrDoc.length > 0 || apiVerDoc.length > 0) { // if the method has no doc comment but we're adding other // doc comments, add an empty method name comment. this preserves // existing behavior and makes the docs look better overall. text += `// ${methodName} -\n`; } text += respErrDoc; text += apiVerDoc; if (go.isLROMethod(method)) { methodName = method.naming.internalMethod; } else { for (const param of values(helpers.getMethodParameters(method))) { text += helpers.formatCommentAsBulletItem(param.name, param.docs); } } text += `func ${getClientReceiverDefinition(method.receiver)} ${methodName}(${params}) (${returns.join(', ')}) {\n`; const reqParams = helpers.getCreateRequestParameters(method); if (method.kind === 'pageableMethod') { text += '\treturn '; text += emitPagerDefinition(method, imports, injectSpans, generateFakes); text += '}\n\n'; return text; } text += '\tvar err error\n'; let operationName = `"${method.receiver.type.name}.${fixUpMethodName(method)}"`; if (generateFakes && injectSpans) { text += `\tconst operationName = ${operationName}\n`; operationName = 'operationName'; } if (generateFakes) { text += `\tctx = context.WithValue(ctx, runtime.CtxAPINameKey{}, ${operationName})\n`; } if (injectSpans) { text += `\tctx, endSpan := runtime.StartSpan(ctx, ${operationName}, client.internal.Tracer(), nil)\n`; text += '\tdefer func() { endSpan(err) }()\n'; } const zeroResp = getZeroReturnValue(method, 'op'); text += `\treq, err := client.${method.naming.requestMethod}(${reqParams})\n`; text += '\tif err != nil {\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; text += '\thttpResp, err := client.internal.Pipeline().Do(req)\n'; text += '\tif err != nil {\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; text += `\tif !runtime.HasStatusCode(httpResp, ${helpers.formatStatusCodes(method.httpStatusCodes)}) {\n`; text += '\t\terr = runtime.NewResponseError(httpResp)\n'; text += `\t\treturn ${zeroResp}, err\n`; text += '\t}\n'; // HAB with headers response is handled in protocol responder if (method.returns.result?.kind === 'headAsBooleanResult' && method.returns.headers.length === 0) { text += `\treturn ${method.returns.name}{${method.returns.result.fieldName}: httpResp.StatusCode >= 200 && httpResp.StatusCode < 300}, nil\n`; } else { if (go.isLROMethod(method)) { text += '\treturn httpResp, nil\n'; } else if (needsResponseHandler(method)) { // also cheating here as at present the only param to the responder is an http.Response text += `\tresp, err := client.${method.naming.responseMethod}(httpResp)\n`; text += '\treturn resp, err\n'; } else if (method.returns.result?.kind === 'binaryResult') { text += `\treturn ${method.returns.name}{${method.returns.result.fieldName}: httpResp.Body}, nil\n`; } else { text += `\treturn ${method.returns.name}{}, nil\n`; } } text += '}\n\n'; return text; } function createProtocolRequest(azureARM, method, imports) { let name = method.name; if (method.kind !== 'nextPageMethod') { name = method.naming.requestMethod; } for (const param of values(method.parameters)) { if (param.location !== 'method' || !go.isRequiredParameter(param.style)) { continue; } imports.addImportForType(param.type); } const returns = ['*policy.Request', 'error']; let text = `${comment(name, '// ')} creates the ${method.name} request.\n`; text += `func ${getClientReceiverDefinition(method.receiver)} ${name}(${helpers.getCreateRequestParametersSig(method)}) (${returns.join(', ')}) {\n`; const hostParams = new Array(); for (const parameter of method.receiver.type.parameters) { if (parameter.kind === 'uriParam') { hostParams.push(parameter); } } let hostParam; if (azureARM) { hostParam = 'client.internal.Endpoint()'; } else if (method.receiver.type.instance?.kind === 'templatedHost') { imports.add('strings'); // we have a templated host text += `\thost := "${method.receiver.type.instance.path}"\n`; // get all the host params on the client for (const hostParam of hostParams) { text += `\thost = strings.ReplaceAll(host, "{${hostParam.uriPathSegment}}", ${helpers.formatValue(`client.${hostParam.name}`, hostParam.type, imports)})\n`; } // check for any method local host params for (const param of values(method.parameters)) { if (param.location === 'method' && param.kind === 'uriParam') { text += `\thost = strings.ReplaceAll(host, "{${param.uriPathSegment}}", ${helpers.formatValue(helpers.getParamName(param), param.type, imports)})\n`; } } hostParam = 'host'; } else if (hostParams.length === 1) { // simple parameterized host case hostParam = 'client.' + hostParams[0].name; } else { throw new CodegenError('InternalError', `no host or endpoint defined for method ${method.receiver.type.name}.${method.name}`); } const methodParamGroups = helpers.getMethodParamGroups(method); const hasPathParams = methodParamGroups.pathParams.length > 0; // storage needs the client.u to be the source-of-truth for the full path. // however, swagger requires that all operations specify a path, which is at odds with storage. // to work around this, storage specifies x-ms-path paths with path params but doesn't // actually reference the path params (i.e. no params with which to replace the tokens). // so, if a path contains tokens but there are no path params, skip emitting the path. const pathStr = method.httpPath; const pathContainsParms = pathStr.includes('{'); if (hasPathParams || (!pathContainsParms && pathStr.length > 1)) { // there are path params, or the path doesn't contain tokens and is not "/" so emit it text += `\turlPath := "${method.httpPath}"\n`; hostParam = `runtime.JoinPaths(${hostParam}, urlPath)`; } // helper to build nil checks for param groups const emitParamGroupCheck = function (param) { if (!param.group) { throw new CodegenError('InternalError', `emitParamGroupCheck called for ungrouped parameter ${param.name}`); } let client = ''; if (param.location === 'client') { client = 'client.'; } const paramGroupName = uncapitalize(param.group.name); let optionalParamGroupCheck = `${client}${paramGroupName} != nil && `; if (param.group.required) { optionalParamGroupCheck = ''; } return `\tif ${optionalParamGroupCheck}${client}${paramGroupName}.${capitalize(param.name)} != nil {\n`; }; if (hasPathParams) { // swagger defines path params, emit path and replace tokens imports.add('strings'); // replace path parameters for (const pp of methodParamGroups.pathParams) { let paramValue; let optionalPathSep = false; if (pp.style !== 'optional') { // emit check to ensure path param isn't an empty string if (pp.kind === 'pathScalarParam') { const choiceIsString = function (type) { return type.kind === 'constant' && type.type === 'string'; }; // we only need to do this for params that have an underlying type of string if ((pp.type.kind === 'string' || choiceIsString(pp.type)) && !pp.omitEmptyStringCheck) { const paramName = helpers.getParamName(pp); imports.add('errors'); text += `\tif ${paramName} == "" {\n`; text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`; text += '\t}\n'; } } paramValue = helpers.formatParamValue(pp, imports); // for collection-based path params, we emit the empty check // after calling helpers.formatParamValue as that will have the // var name that contains the slice. if (pp.kind === 'pathCollectionParam') { const paramName = helpers.getParamName(pp); const joinedParamName = `${paramName}Param`; text += `\t${joinedParamName} := ${paramValue}\n`; imports.add('errors'); text += `\tif len(${joinedParamName}) == 0 {\n`; text += `\t\treturn nil, errors.New("parameter ${paramName} cannot be empty")\n`; text += '\t}\n'; paramValue = joinedParamName; } } else { // param isn't required, so emit a local var with // the correct default value, then populate it with // the optional value when set. paramValue = `optional${capitalize(pp.name)}`; text += `\t${paramValue} := ""\n`; text += emitParamGroupCheck(pp); text += `\t${paramValue} = ${helpers.formatParamValue(pp, imports)}\n\t}\n`; // there are two cases for optional path params. // - /foo/bar/{optional} // - /foo/bar{/optional} // for the second case, we need to include a forward slash if (method.httpPath[method.httpPath.indexOf(`{${pp.pathSegment}}`) - 1] !== '/') { optionalPathSep = true; } } const emitPathEscape = function () { if (pp.isEncoded) { imports.add('net/url'); return `url.PathEscape(${paramValue})`; } return paramValue; }; if (optionalPathSep) { text += `\tif len(${paramValue}) > 0 {\n`; text += `\t\t${paramValue} = "/"+${emitPathEscape()}\n`; text += '\t}\n'; } else { paramValue = emitPathEscape(); } text += `\turlPath = strings.ReplaceAll(urlPath, "{${pp.pathSegment}}", ${paramValue})\n`; } } text += `\treq, err := runtime.NewRequest(ctx, http.Method${capitalize(method.httpMethod)}, ${hostParam})\n`; text += '\tif err != nil {\n'; text += '\t\treturn nil, err\n'; text += '\t}\n'; // add query parameters const encodedParams = methodParamGroups.encodedQueryParams; const unencodedParams = methodParamGroups.unencodedQueryParams; const emitQueryParam = function (qp, setter) { let qpText = ''; if (qp.location === 'method' && go.isClientSideDefault(qp.style)) { qpText = emitClientSideDefault(qp, qp.style, (name, val) => { return `\treqQP.Set(${name}, ${val})`; }, imports); } else if (go.isRequiredParameter(qp.style) || go.isLiteralParameter(qp.style) || (qp.location === 'client' && go.isClientSideDefault(qp.style))) { qpText = `\t${setter}\n`; } else if (qp.location === 'client' && !qp.group) { // global optional param qpText = `\tif client.${qp.name} != nil {\n`; qpText += `\t\t${setter}\n`; qpText += '\t}\n'; } else { qpText = emitParamGroupCheck(qp); qpText += `\t\t${setter}\n`; qpText += '\t}\n'; } return qpText; }; // emit encoded params first if (encodedParams.length > 0) { text += '\treqQP := req.Raw().URL.Query()\n'; for (const qp of values(encodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) { let setter; if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') { setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`; // emit a type conversion for the qv based on the array's element type let queryVal; const arrayQP = qp.type; switch (arrayQP.elementType.kind) { case 'constant': switch (arrayQP.elementType.type) { case 'string': queryVal = 'string(qv)'; break; default: imports.add('fmt'); queryVal = 'fmt.Sprintf("%d", qv)'; } break; case 'string': queryVal = 'qv'; break; default: imports.add('fmt'); queryVal = 'fmt.Sprintf("%v", qv)'; } setter += `\t\treqQP.Add("${qp.queryParameter}", ${queryVal})\n`; setter += '\t}'; } else { // cannot initialize setter to this value as helpers.formatParamValue() can change imports setter = `reqQP.Set("${qp.queryParameter}", ${helpers.formatParamValue(qp, imports)})`; } text += emitQueryParam(qp, setter); } text += '\treq.Raw().URL.RawQuery = reqQP.Encode()\n'; } // tack on any unencoded params to the end if (unencodedParams.length > 0) { if (encodedParams.length > 0) { text += '\tunencodedParams := []string{req.Raw().URL.RawQuery}\n'; } else { text += '\tunencodedParams := []string{}\n'; } for (const qp of values(unencodedParams.sort((a, b) => { return helpers.sortAscending(a.queryParameter, b.queryParameter); }))) { let setter; if (qp.kind === 'queryCollectionParam' && qp.collectionFormat === 'multi') { setter = `\tfor _, qv := range ${helpers.getParamName(qp)} {\n`; setter += `\t\tunencodedParams = append(unencodedParams, "${qp.queryParameter}="+qv)\n`; setter += '\t}'; } else { setter = `unencodedParams = append(unencodedParams, "${qp.queryParameter}="+${helpers.formatParamValue(qp, imports)})`; } text += emitQueryParam(qp, setter); } imports.add('strings'); text += '\treq.Raw().URL.RawQuery = strings.Join(unencodedParams, "&")\n'; } if (method.kind !== 'nextPageMethod' && method.returns.result?.kind === 'binaryResult') { // skip auto-body downloading for binary stream responses text += '\truntime.SkipBodyDownload(req)\n'; } // add specific request headers const emitHeaderSet = function (headerParam, prefix) { if (headerParam.kind === 'headerMapParam') { let headerText = `${prefix}for k, v := range ${helpers.getParamName(headerParam)} {\n`; headerText += `${prefix}\tif v != nil {\n`; headerText += `${prefix}\t\treq.Raw().Header["${headerParam.headerName}"+k] = []string{*v}\n`; headerText += `${prefix}}\n`; headerText += `${prefix}}\n`; return headerText; } else if (headerParam.location === 'method' && go.isClientSideDefault(headerParam.style)) { return emitClientSideDefault(headerParam, headerParam.style, (name, val) => { return `${prefix}req.Raw().Header[${name}] = []string{${val}}`; }, imports); } else { return `${prefix}req.Raw().Header["${headerParam.headerName}"] = []string{${helpers.formatParamValue(headerParam, imports)}}\n`; } }; let contentType; for (const param of methodParamGroups.headerParams.sort((a, b) => { return helpers.sortAscending(a.headerName, b.headerName); })) { if (param.headerName.match(/^content-type$/)) { // canonicalize content-type as req.SetBody checks for it via its canonicalized name :( param.headerName = 'Content-Type'; } if (param.headerName === 'Content-Type' && param.style === 'literal') { // the content-type header will be set as part of emitSetBodyWithErrCheck // to handle cases where the body param is optional. we don't want to set // the content-type if the body is nil. // we do it like this as tsp specifies content-type while swagger does not. contentType = helpers.formatParamValue(param, imports); } else if (go.isRequiredParameter(param.style) || go.isLiteralParameter(param.style) || go.isClientSideDefault(param.style)) { text += emitHeaderSet(param, '\t'); } else if (param.location === 'client' && !param.group) { // global optional param text += `\tif client.${param.name} != nil {\n`; text += emitHeaderSet(param, '\t'); text += '\t}\n'; } else { text += emitParamGroupCheck(param); text += emitHeaderSet(param, '\t\t'); text += '\t}\n'; } } // note that these are mutually exclusive const bodyParam = methodParamGroups.bodyParam; const formBodyParams = methodParamGroups.formBodyParams; const multipartBodyParams = methodParamGroups.multipartBodyParams; const partialBodyParams = methodParamGroups.partialBodyParams; const emitSetBodyWithErrCheck = function (setBodyParam, contentType) { let content = `if err := ${setBodyParam}; err != nil {\n\treturn nil, err\n}\n;`; if (contentType) { content = `req.Raw().Header["Content-Type"] = []string{${contentType}}\n` + content; } return content; }; if (bodyParam) { if (bodyParam.bodyFormat === 'JSON' || bodyParam.bodyFormat === 'XML') { // default to the body param name let body = helpers.getParamName(bodyParam); if (bodyParam.type.kind === 'literal') { // if the value is constant, embed it directly body = helpers.formatLiteralValue(bodyParam.type, true); } else if (bodyParam.bodyFormat === 'XML' && bodyParam.type.kind === 'slice') { // for XML payloads, create a wrapper type if the payload is an array imports.add('encoding/xml'); text += '\ttype wrapper struct {\n'; let tagName = go.getTypeDeclaration(bodyParam.type); if (bodyParam.xml?.name) { tagName = bodyParam.xml.name; } text += `\t\tXMLName xml.Name \`xml:"${tagName}"\`\n`; const fieldName = capitalize(bodyParam.name); let tag = go.getTypeDeclaration(bodyParam.type.elementType); if (bodyParam.type.elementType.kind === 'model' && bodyParam.type.elementType.xml?.name) { tag = bodyParam.type.elementType.xml.name; } text += `\t\t${fieldName} *${go.getTypeDeclaration(bodyParam.type)} \`xml:"${tag}"\`\n`; text += '\t}\n'; let addr = '&'; if (!go.isRequiredParameter(bodyParam.style) && !bodyParam.byValue) { addr = ''; } body = `wrapper{${fieldName}: ${addr}${body}}`; } else if (bodyParam.type.kind === 'time' && bodyParam.type.format !== 'dateTimeRFC3339') { // wrap the body in the internal time type // no need for dateTimeRFC3339 as the JSON marshaler defaults to that. body = `${bodyParam.type.format}(${body})`; } else if (isArrayOfDateTimeForMarshalling(bodyParam.type)) { const timeInfo = isArrayOfDateTimeForMarshalling(bodyParam.type); let elementPtr = '*';