// @autorest/go
// Version:
// AutoRest Go Generator
// 909 lines • 72.1 kB
// JavaScript
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { camelCase, capitalize, uncapitalize } from '@azure-tools/codegen';
import { values } from '@azure-tools/linq';
import * as go from '../../../codemodel.go/src/index.js';
import * as helpers from '../helpers.js';
import { ImportManager } from '../imports.js';
import { fixUpMethodName } from '../operations.js';
import { generateServerInternal, RequiredHelpers } from './internal.js';
import { CodegenError } from '../errors.js';
/**
 * Holds the complete output of fake-server generation: the generated
 * content for all servers together with the required helpers.
 */
export class ServerContent {
  /** the generated content for the servers */
  servers;
  /** the generated content for the required helpers */
  internals;

  constructor(servers, internals) {
    Object.assign(this, { servers, internals });
  }
}
/**
 * Pairs the name of an operation group with the content generated for it.
 */
export class OperationGroupContent {
  /** the name of the operation group */
  name;
  /** the generated content for the operation group */
  content;

  constructor(name, content) {
    Object.assign(this, { name, content });
  }
}
// used to track the helpers we need to emit. they're all false by default.
// NOTE: this is module-level state shared by the generator functions in this file.
// in the code visible here flags are only ever set to true (never reset), and the
// instance is handed to generateServerInternal at the end of generateServers.
const requiredHelpers = new RequiredHelpers();
/**
 * Returns the name of the fake server type for the specified client.
 *
 * A trailing "Client" or "client" in the client's name is replaced with
 * "Server" and the result is capitalized. Names without that suffix are
 * only capitalized.
 *
 * @param client the client whose name is used to derive the server name
 * @returns the fake server type name
 */
export function getServerName(client) {
  // for the fake server, we use the suffix Server instead of Client.
  // note that [Cc] is used instead of [C|c] as '|' inside a character
  // class is a literal and would also match a "|lient" suffix.
  return capitalize(client.name.replace(/[Cc]lient$/, 'Server'));
}
/**
 * Generates the Go source for the fake servers of the clients in the code model.
 *
 * Clients that have no client accessors and whose methods are all internal are skipped.
 * For every other client the emitted text contains, in order: the fake server struct
 * (one field per sub-client server and one func field per exported method), the
 * New<Server>Transport constructor, the <Server>Transport struct, the Do and dispatch
 * methods, and the transport interceptor variable.
 *
 * Side effects: sets the tracker and initServer flags on the module-level requiredHelpers
 * when needed; the generator functions called from here may set further flags.
 *
 * @param codeModel the code model to generate from (its packageName and clients are read here)
 * @returns (async) a ServerContent holding one OperationGroupContent per generated server
 *          and the result of generateServerInternal for the accumulated helper flags
 */
export async function generateServers(codeModel) {
// one OperationGroupContent per generated server
const operations = new Array();
const clientPkg = codeModel.packageName;
for (const client of values(codeModel.clients)) {
if (client.clientAccessors.length === 0 && values(client.methods).all(method => { return helpers.isMethodInternal(method); })) {
// client has no client accessors and no exported methods, skip it
continue;
}
// the list of packages to import
const imports = new ImportManager();
// add standard imports
imports.add('errors');
imports.add('fmt');
imports.add('net/http');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
const serverName = getServerName(client);
let content;
content = `// ${serverName} is a fake server for instances of the ${clientPkg}.${client.name} type.\n`;
content += `type ${serverName} struct{\n`;
// we might remove some operations from the list
const finalMethods = new Array();
// the counts decide below whether the constructor has any trackers to initialize
let countLROs = 0;
let countPagers = 0;
// add server transports for client accessors
// we might remove some clients from the list
const finalSubClients = new Array();
for (const clientAccessor of client.clientAccessors) {
if (values(clientAccessor.subClient.methods).all(method => { return helpers.isMethodInternal(method); })) {
// client has no exported methods, skip it
continue;
}
// note that this serverName is the sub-client's server and shadows the outer serverName
const serverName = getServerName(clientAccessor.subClient);
content += `\t// ${serverName} contains the fakes for client ${clientAccessor.subClient.name}\n`;
content += `\t${serverName} ${serverName}\n\n`;
finalSubClients.push(clientAccessor.subClient);
}
for (const method of values(client.methods)) {
if (helpers.isMethodInternal(method)) {
// method isn't exported, don't create a fake for it
continue;
}
// the Go return values of the fake's func field; the shape depends on the method kind.
// NOTE(review): there is no default case, so an unrecognized kind leaves this undefined
let serverResponse;
switch (method.kind) {
case 'lroMethod':
case 'lroPageableMethod':
// for a pageable LRO the poller responder wraps a pager responder
let respType = `${clientPkg}.${method.returns.name}`;
if (method.kind === 'lroPageableMethod') {
respType = `azfake.PagerResponder[${clientPkg}.${method.returns.name}]`;
}
serverResponse = `resp azfake.PollerResponder[${respType}], errResp azfake.ErrorResponder`;
break;
case 'method':
serverResponse = `resp azfake.Responder[${clientPkg}.${method.returns.name}], errResp azfake.ErrorResponder`;
break;
case 'pageableMethod':
// pagers have no separate error responder in the fake's signature
serverResponse = `resp azfake.PagerResponder[${clientPkg}.${method.returns.name}]`;
break;
}
const operationName = fixUpMethodName(method);
content += `\t// ${operationName} is the fake for method ${client.name}.${operationName}\n`;
// document the status codes that indicate success in the field's doc comment
const successCodes = new Array();
if (method.returns.result?.kind === 'anyResult') {
// for anyResult, list each status code on its own line along with the type it returns
for (const httpStatus of getMethodStatusCodes(method)) {
const result = method.returns.result.httpStatusCodeType[httpStatus];
if (!result) {
// the operation contains a mix of schemas and non-schema responses
successCodes.push(`${helpers.formatStatusCode(httpStatus)} (no return type)`);
continue;
}
successCodes.push(`${helpers.formatStatusCode(httpStatus)} (returns ${go.getTypeDeclaration(result, clientPkg)})`);
}
content += '\t// HTTP status codes to indicate success:\n';
for (const successCode of successCodes) {
content += `\t// - ${successCode}\n`;
}
}
else {
for (const statusCode of getMethodStatusCodes(method)) {
successCodes.push(`${helpers.formatStatusCode(statusCode)}`);
}
content += `\t// HTTP status codes to indicate success: ${successCodes.join(', ')}\n`;
}
content += `\t${operationName} func(${getAPIParametersSig(method, imports, clientPkg)}) (${serverResponse})\n\n`;
finalMethods.push(method);
// pageable LROs are counted as LROs
switch (method.kind) {
case 'lroMethod':
case 'lroPageableMethod':
++countLROs;
break;
case 'pageableMethod':
++countPagers;
break;
}
}
content += '}\n\n';
///////////////////////////////////////////////////////////////////////////
const serverTransport = `${serverName}Transport`;
content += `// New${serverTransport} creates a new instance of ${serverTransport} with the provided implementation.\n`;
content += `// The returned ${serverTransport} instance is connected to an instance of ${clientPkg}.${client.name} via the\n`;
content += '// azcore.ClientOptions.Transporter field in the client\'s constructor parameters.\n';
content += `func New${serverTransport}(srv *${serverName}) *${serverTransport} {\n`;
if (countLROs === 0 && countPagers === 0) {
// no pollers or pagers means no trackers to initialize
content += `\treturn &${serverTransport}{srv: srv}\n}\n\n`;
}
else {
// initialize a tracker for each poller/pager operation
content += `\treturn &${serverTransport}{\n\t\tsrv: srv,\n`;
for (const method of values(finalMethods)) {
let respType = `${clientPkg}.${method.returns.name}`;
switch (method.kind) {
case 'lroMethod':
case 'lroPageableMethod':
if (method.kind === 'lroPageableMethod') {
respType = `azfake.PagerResponder[${clientPkg}.${method.returns.name}]`;
}
requiredHelpers.tracker = true;
content += `\t\t${uncapitalize(fixUpMethodName(method))}: newTracker[azfake.PollerResponder[${respType}]](),\n`;
break;
case 'pageableMethod':
requiredHelpers.tracker = true;
content += `\t\t${uncapitalize(fixUpMethodName(method))}: newTracker[azfake.PagerResponder[${respType}]](),\n`;
break;
}
}
content += '\t}\n}\n\n';
}
content += `// ${serverTransport} connects instances of ${clientPkg}.${client.name} to instances of ${serverName}.\n`;
content += `// Don't use this type directly, use New${serverTransport} instead.\n`;
content += `type ${serverTransport} struct {\n`;
content += `\tsrv *${serverName}\n`;
// add server transports for client accessors
if (finalSubClients.length > 0) {
// the sub-client transports are created via initServer with trMu passed as the mutex
requiredHelpers.initServer = true;
imports.add('sync');
content += '\ttrMu sync.Mutex\n';
for (const subClient of finalSubClients) {
const serverName = getServerName(subClient);
content += `\ttr${serverName} *${serverName}Transport\n`;
}
}
for (const method of values(finalMethods)) {
// create state machines for any pager/poller operations
let respType = `${clientPkg}.${method.returns.name}`;
switch (method.kind) {
case 'lroMethod':
case 'lroPageableMethod':
if (method.kind === 'lroPageableMethod') {
respType = `azfake.PagerResponder[${clientPkg}.${method.returns.name}]`;
}
requiredHelpers.tracker = true;
content += `\t${uncapitalize(fixUpMethodName(method))} *tracker[azfake.PollerResponder[${respType}]]\n`;
break;
case 'pageableMethod':
requiredHelpers.tracker = true;
content += `\t${uncapitalize(fixUpMethodName(method))} *tracker[azfake.PagerResponder[${clientPkg}.${method.returns.name}]]\n`;
break;
}
}
content += '}\n\n';
// emit the transport's methods. note that these calls can add more imports,
// so they must happen before imports.text() is called below.
content += generateServerTransportDo(serverTransport, client, finalSubClients, finalMethods);
content += generateServerTransportClientDispatch(serverTransport, finalSubClients, imports);
content += generateServerTransportMethodDispatch(serverTransport, client, finalMethods);
content += generateServerTransportMethods(codeModel, serverTransport, finalMethods, imports);
content += `// set this to conditionally intercept incoming requests to ${serverTransport}\n`;
content += `var ${getTransportInterceptorVarName(client)} interface {\n`;
content += '\t// Do returns true if the server transport should use the returned response/error\n';
content += '\tDo(*http.Request) (*http.Response, error, bool)\n}\n';
///////////////////////////////////////////////////////////////////////////
// stitch everything together
let text = helpers.contentPreamble(codeModel, true, 'fake');
text += imports.text();
text += content;
operations.push(new OperationGroupContent(serverName, text));
}
// the helper flags have been accumulated across all clients at this point
return new ServerContent(operations, generateServerInternal(codeModel, requiredHelpers));
}
/**
 * Returns the name of the Go variable used to intercept requests sent to
 * the specified client's server transport.
 *
 * @param client the client for which to build the variable name
 * @returns the camel-cased server name followed by "TransportInterceptor"
 */
function getTransportInterceptorVarName(client) {
  const serverName = getServerName(client);
  const prefix = camelCase(serverName);
  return `${prefix}TransportInterceptor`;
}
// method names for fakes dispatching.
// these are the names of the Go methods emitted on each server transport:
// dispatchMethodFake names the one that routes a request to the client's own method fakes,
// dispatchToClientFake names the one that routes a request to a sub-client's transport.
const dispatchMethodFake = 'dispatchToMethodFake';
const dispatchToClientFake = 'dispatchToClientFake';
/**
 * Emits the Go Do method for the server transport.
 *
 * The emitted method reads the API name from the request's context and then
 * dispatches to the sub-client fakes, the method fakes, or both depending on
 * what the client contains.
 *
 * @param serverTransport the name of the server transport type
 * @param client the client the transport is for (its name is used for routing)
 * @param finalSubClients the sub-clients that have fakes
 * @param finalMethods the methods that have fakes
 * @returns the Go source text for the Do method
 */
function generateServerTransportDo(serverTransport, client, finalSubClients, finalMethods) {
  const recv = serverTransport[0].toLowerCase();
  const hasSubClients = finalSubClients.length > 0;
  const hasMethods = finalMethods.length > 0;
  const parts = [
    `// Do implements the policy.Transporter interface for ${serverTransport}.\n`,
    `func (${recv} *${serverTransport}) Do(req *http.Request) (*http.Response, error) {\n`,
    '\trawMethod := req.Context().Value(runtime.CtxAPINameKey{})\n',
    '\tmethod, ok := rawMethod.(string)\n',
    '\tif !ok {\n\t\treturn nil, nonRetriableError{errors.New("unable to dispatch request, missing value for CtxAPINameKey")}\n\t}\n\n',
  ];
  if (!hasSubClients) {
    // no client accessors, so everything goes to our method fakes
    parts.push(`\treturn ${recv}.${dispatchMethodFake}(req, method)\n`);
  }
  else if (hasMethods) {
    // client contains client accessors and methods.
    // if the method isn't for this client, dispatch to the correct client
    parts.push(`\tif client := method[:strings.Index(method, ".")]; client != "${client.name}" {\n`);
    parts.push(`\t\treturn ${recv}.${dispatchToClientFake}(req, client)\n\t}\n`);
    // else dispatch to our method fakes
    parts.push(`\treturn ${recv}.${dispatchMethodFake}(req, method)\n`);
  }
  else {
    // only client accessors, so always dispatch to the named client
    parts.push(`\treturn ${recv}.${dispatchToClientFake}(req, method[:strings.Index(method, ".")])\n`);
  }
  parts.push('}\n\n'); // end Do
  return parts.join('');
}
/**
 * Emits the Go method that dispatches a request to the server transport of
 * the matching sub-client.
 *
 * Adds the "strings" import when there is at least one sub-client.
 *
 * @param serverTransport the name of the server transport type
 * @param subClients the sub-clients that have fakes
 * @param imports the import manager for the file being generated
 * @returns the Go source text, or the empty string when there are no sub-clients
 */
function generateServerTransportClientDispatch(serverTransport, subClients, imports) {
  if (subClients.length === 0) {
    return '';
  }
  const recv = serverTransport[0].toLowerCase();
  imports.add('strings');
  // one case per sub-client: lazily create its transport then forward the request
  const cases = Array.from(subClients, (subClient) => {
    const caseLabel = `\tcase "${subClient.name}":\n`;
    const fakeName = getServerName(subClient);
    return [
      caseLabel,
      `\t\tinitServer(&${recv}.trMu, &${recv}.tr${fakeName}, func() *${fakeName}Transport {\n\t\treturn New${fakeName}Transport(&${recv}.srv.${fakeName}) })\n`,
      `\t\tresp, err = ${recv}.tr${fakeName}.Do(req)\n`,
    ].join('');
  });
  return [
    `func (${recv} *${serverTransport}) ${dispatchToClientFake}(req *http.Request, client string) (*http.Response, error) {\n`,
    '\tvar resp *http.Response\n\tvar err error\n\n',
    '\tswitch client {\n',
    ...cases,
    '\tdefault:\n\t\terr = fmt.Errorf("unhandled client %s", client)\n',
    '\t}\n\n', // end switch
    '\treturn resp, err\n}\n\n',
  ].join('');
}
/**
 * Emits the Go method that dispatches a request to the fake for the named API.
 *
 * The emitted method does the work in a goroutine: it first consults the
 * transport interceptor variable (when it's not nil), otherwise switches on
 * the API name, then hands the result back over a channel unless the
 * request's context is done first.
 *
 * @param serverTransport the name of the server transport type
 * @param client the client the transport is for
 * @param finalMethods the methods that have fakes
 * @returns the Go source text, or the empty string when there are no methods
 */
function generateServerTransportMethodDispatch(serverTransport, client, finalMethods) {
  if (finalMethods.length === 0) {
    return '';
  }
  const recv = serverTransport[0].toLowerCase();
  const interceptor = getTransportInterceptorVarName(client);
  // one case per method, routing to its dispatch<Operation> method
  const cases = Array.from(values(finalMethods), (method) => {
    const opName = fixUpMethodName(method);
    return [
      `\t\t\tcase "${client.name}.${opName}":\n`,
      `\t\t\t\tres.resp, res.err = ${recv}.dispatch${opName}(req)\n`,
    ].join('');
  });
  return [
    `func (${recv} *${serverTransport}) ${dispatchMethodFake}(req *http.Request, method string) (*http.Response, error) {\n`,
    '\tresultChan := make(chan result)\n',
    '\tdefer close(resultChan)\n\n',
    '\tgo func() {\n\t\tvar intercepted bool\n\t\tvar res result\n',
    `\t\t if ${interceptor} != nil {\n`,
    `\t\t\t res.resp, res.err, intercepted = ${interceptor}.Do(req)\n\t\t}\n`,
    '\t\tif !intercepted {\n',
    '\t\t\tswitch method {\n',
    ...cases,
    '\t\t\t\tdefault:\n\t\tres.err = fmt.Errorf("unhandled API %s", method)\n',
    '\t\t\t}\n\n', // end switch
    '\t\t}\n', // end if !intercepted
    '\t\tselect {\n',
    '\t\tcase resultChan <- res:\n',
    '\t\tcase <-req.Context().Done():\n',
    '\t\t}\n',
    '\t}()\n\n', // end goroutine
    '\tselect {\n',
    '\tcase <-req.Context().Done():\n',
    '\t\treturn nil, req.Context().Err()\n',
    '\tcase res := <-resultChan:\n',
    '\t\treturn res.resp, res.err\n',
    '\t}\n}\n\n',
  ].join('');
}
/**
 * Emits one dispatch<Operation> Go method on the server transport for each method in finalMethods.
 *
 * Every emitted method starts by returning an error when the corresponding fake on the server
 * is nil. LRO and pager kinds then delegate to dispatchForLROBody / dispatchForPagerBody. For
 * the 'method' kind the body is built here: call the fake, validate the returned status code,
 * create the response according to the kind of result, then copy any header response values.
 *
 * Side effects: adds imports to the provided import manager.
 *
 * @param codeModel the code model (its packageName is used as the client package)
 * @param serverTransport the name of the server transport type
 * @param finalMethods the methods that have fakes
 * @param imports the import manager for the file being generated
 * @returns the Go source text, or the empty string when there are no methods
 */
function generateServerTransportMethods(codeModel, serverTransport, finalMethods, imports) {
if (finalMethods.length === 0) {
return '';
}
imports.add(helpers.getParentImport(codeModel));
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/fake', 'azfake');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/fake/server');
const receiverName = serverTransport[0].toLowerCase();
let content = '';
for (const method of values(finalMethods)) {
content += `func (${receiverName} *${serverTransport}) dispatch${fixUpMethodName(method)}(req *http.Request) (*http.Response, error) {\n`;
// a nil fake means the caller didn't provide an implementation for this method
content += `\tif ${receiverName}.srv.${fixUpMethodName(method)} == nil {\n`;
content += `\t\treturn nil, &nonRetriableError{errors.New("fake for method ${fixUpMethodName(method)} not implemented")}\n\t}\n`;
switch (method.kind) {
case 'lroMethod':
case 'lroPageableMethod':
// must check LRO before pager as you can have paged LROs
content += dispatchForLROBody(codeModel.packageName, receiverName, method, imports);
break;
case 'method': {
// dispatchForOperationBody emits the call to the fake which declares respr/errRespr
content += dispatchForOperationBody(codeModel.packageName, receiverName, method, imports);
content += '\trespContent := server.GetResponseContent(respr)\n';
const formattedStatusCodes = helpers.formatStatusCodes(method.httpStatusCodes);
content += `\tif !contains([]int{${formattedStatusCodes}}, respContent.HTTPStatus) {\n`;
content += `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${formattedStatusCodes}", respContent.HTTPStatus)}\n\t}\n`;
// each branch below emits the declaration of resp and err for its kind of result
if (!method.returns.result || method.returns.result.kind === 'headAsBooleanResult') {
// no response body
content += '\tresp, err := server.NewResponse(respContent, req, nil)\n';
}
else if (method.returns.result.kind === 'anyResult') {
content += `\tresp, err := server.MarshalResponseAs${method.returns.result.format}(respContent, server.GetResponse(respr).${getResultFieldName(method.returns.result)}, req)\n`;
}
else if (method.returns.result.kind === 'binaryResult') {
// the response's content type is taken from the request's Content-Type header
content += '\tresp, err := server.NewResponse(respContent, req, &server.ResponseOptions{\n';
content += `\t\tBody: server.GetResponse(respr).${getResultFieldName(method.returns.result)},\n`;
content += '\t\tContentType: req.Header.Get("Content-Type"),\n';
content += '\t})\n';
}
else if (method.returns.result.kind === 'monomorphicResult') {
if (method.returns.result.monomorphicType.kind === 'encodedBytes') {
const encoding = method.returns.result.monomorphicType.encoding;
content += `\tresp, err := server.MarshalResponseAsByteArray(respContent, server.GetResponse(respr).${getResultFieldName(method.returns.result)}, runtime.Base64${encoding}Format, req)\n`;
}
else if (method.returns.result.monomorphicType.kind === 'rawJSON') {
// raw JSON is written to the body as-is
imports.add('bytes');
imports.add('io');
content += '\tresp, err := server.NewResponse(respContent, req, &server.ResponseOptions{\n';
content += '\t\tBody: io.NopCloser(bytes.NewReader(server.GetResponse(respr).RawJSON)),\n';
content += '\t\tContentType: "application/json",\n\t})\n';
}
else {
let respField = `.${getResultFieldName(method.returns.result)}`;
if (method.returns.result.format === 'XML' && method.returns.result.monomorphicType.kind === 'slice') {
// for XML array responses we use the response type directly as it has the necessary XML tag for proper marshalling
respField = '';
}
let responseField = `server.GetResponse(respr)${respField}`;
if (method.returns.result.monomorphicType.kind === 'time') {
// time values are converted to a pointer to the time type's format type before marshalling
responseField = `(*${method.returns.result.monomorphicType.format})(${responseField})`;
}
content += `\tresp, err := server.MarshalResponseAs${method.returns.result.format}(respContent, ${responseField}, req)\n`;
}
}
else if (method.returns.result.kind === 'modelResult' || method.returns.result.kind === 'polymorphicResult') {
const respField = `.${getResultFieldName(method.returns.result)}`;
const responseField = `server.GetResponse(respr)${respField}`;
content += `\tresp, err := server.MarshalResponseAs${method.returns.result.format}(respContent, ${responseField}, req)\n`;
}
// NOTE(review): a result kind not handled above emits no resp/err declaration,
// so the Go emitted below would reference undeclared variables — confirm the kinds are exhaustive
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
// propagate any header response values into the *http.Response
for (const header of values(method.returns.headers)) {
if (header.kind === 'headerMapResponse') {
// header maps: headerName is used as a prefix for each key in the map
content += `\tfor k, v := range server.GetResponse(respr).${header.fieldName} {\n`;
content += '\t\tif v != nil {\n';
content += `\t\t\tresp.Header.Set("${header.headerName}"+k, *v)\n`;
content += '\t\t}\n';
content += '\t}\n';
}
else {
// only set the header when the fake populated the field
content += `\tif val := server.GetResponse(respr).${header.fieldName}; val != nil {\n`;
content += `\t\tresp.Header.Set("${header.headerName}", ${helpers.formatValue('val', header.type, imports, true)})\n\t}\n`;
}
}
content += '\treturn resp, nil\n';
break;
}
case 'pageableMethod':
content += dispatchForPagerBody(codeModel.packageName, receiverName, method, imports);
break;
default:
// no-op expression statement.
// NOTE(review): presumably left over from a compile-time exhaustiveness check — confirm
method;
}
content += '}\n\n';
}
return content;
}
/**
 * Emits the Go code that parses an incoming request into the fake's parameters and
 * then calls the fake.
 *
 * The emitted code, in order: matches the request path against a regex when there are
 * non-literal path params, reads the query params when needed, parses the request body
 * (body param, multipart, form, or partial body params), parses header/path/query params,
 * and finally invokes the fake on the server.
 *
 * For 'pageableMethod' the fake's return value is assigned to resp and nothing more is
 * emitted. For all other kinds it's assigned to respr/errRespr followed by a check that
 * returns early if the error responder contains an error.
 *
 * Side effects: adds imports to the import manager and can set the readRequestBody flag
 * on the module-level requiredHelpers.
 *
 * @param clientPkg the name of the client package, used to qualify type names
 * @param receiverName the name of the receiver in the emitted Go method
 * @param method the method for which to emit the body
 * @param imports the import manager for the file being generated
 * @returns the Go source text
 * @throws CodegenError for multipart/form parameter types that aren't handled
 */
function dispatchForOperationBody(clientPkg, receiverName, method, imports) {
const methodParamGroups = helpers.getMethodParamGroups(method);
// literal path params are excluded from the count
const numPathParams = values(methodParamGroups.pathParams).where((each) => { return !go.isLiteralParameter(each.style); }).count();
let content = '';
if (numPathParams > 0) {
imports.add('regexp');
content += `\tconst regexStr = \`${createPathParamsRegex(method, methodParamGroups.pathParams)}\`\n`;
content += '\tregex := regexp.MustCompile(regexStr)\n';
content += '\tmatches := regex.FindStringSubmatch(req.URL.EscapedPath())\n';
// the total number of matches is the number of capture groups
// plus the full match. so we add + 1 to include the full match.
content += `\tif len(matches) < ${numPathParams + 1} {\n`;
content += '\t\treturn nil, fmt.Errorf("failed to parse path %s", req.URL.Path)\n\t}\n';
}
const allQueryParams = methodParamGroups.encodedQueryParams.concat(methodParamGroups.unencodedQueryParams);
// only emit qp if there's at least one non-literal query param with location 'method'
if (values(allQueryParams).where((each) => { return each.location === 'method' && !go.isLiteralParameter(each.style); }).any()) {
content += '\tqp := req.URL.Query()\n';
}
// note that these are mutually exclusive
const bodyParam = methodParamGroups.bodyParam;
const formBodyParams = methodParamGroups.formBodyParams;
const multipartBodyParams = methodParamGroups.multipartBodyParams;
const partialBodyParams = methodParamGroups.partialBodyParams;
if (bodyParam) {
switch (bodyParam.bodyFormat) {
case 'JSON':
case 'XML':
// nothing is emitted for literal body params
if (bodyParam && !go.isLiteralParameter(bodyParam.style)) {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/fake', 'azfake');
// every case emits code that declares body and err
switch (bodyParam.type.kind) {
case 'encodedBytes':
content += `\tbody, err := server.UnmarshalRequestAsByteArray(req, runtime.Base64${bodyParam.type.encoding}Format)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
break;
case 'interface':
// read the raw body then hand it to the unmarshal<Name> func for the interface type
requiredHelpers.readRequestBody = true;
content += '\traw, err := readRequestBody(req)\n';
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
content += `\tbody, err := unmarshal${bodyParam.type.name}(raw)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
break;
case 'rawJSON':
// raw JSON is passed along as the bytes read from the body
imports.add('io');
content += '\tbody, err := io.ReadAll(req.Body)\n';
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
content += '\treq.Body.Close()\n';
break;
default: {
let bodyTypeName = go.getTypeDeclaration(bodyParam.type, clientPkg);
if (bodyParam.type.kind === 'time') {
// time values are unmarshalled as the time type's format type
bodyTypeName = bodyParam.type.format;
}
content += `\tbody, err := server.UnmarshalRequestAs${bodyParam.bodyFormat}[${bodyTypeName}](req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
}
}
}
break;
case 'Text':
if (bodyParam && !go.isLiteralParameter(bodyParam.style)) {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/fake', 'azfake');
content += '\tbody, err := server.UnmarshalRequestAsText(req)\n';
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
}
break;
}
// nothing to do for binary media type
}
else if (multipartBodyParams.length > 0) {
imports.add('io');
imports.add('mime');
imports.add('mime/multipart');
content += '\t_, params, err := mime.ParseMediaType(req.Header.Get("Content-Type"))\n';
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
content += '\treader := multipart.NewReader(req.Body, params["boundary"])\n';
// declare a local var for each multipart param; they're populated in the loop emitted below
for (const param of multipartBodyParams) {
let pkgPrefix = '';
switch (param.type.kind) {
case 'constant':
case 'model':
case 'polymorphicModel':
// these kinds get qualified with the client package name
pkgPrefix = clientPkg + '.';
break;
}
content += `\tvar ${param.name} ${pkgPrefix}${go.getTypeDeclaration(param.type)}\n`;
}
// emit a loop that reads each part and switches on the part's form name
content += '\tfor {\n';
content += '\t\tvar part *multipart.Part\n';
content += '\t\tpart, err = reader.NextPart()\n';
content += '\t\tif errors.Is(err, io.EOF) {\n\t\t\tbreak\n';
content += '\t\t} else if err != nil {\n\t\t\treturn nil, err\n\t\t}\n';
content += '\t\tvar content []byte\n';
content += '\t\tswitch fn := part.FormName(); fn {\n';
// specify boolTarget if parsing bools happens in place.
// i.e. the result from the parsing doesn't require further conversion (e.g. casting)
// otherwise the parsed value is in a local var named parsed.
// returns the Go code that parses the part's content as typeName, including the error check.
const parsePrimitiveType = function (typeName, boolTarget) {
let parseErr = 'parseErr';
const parseResults = `parsed, ${parseErr}`;
let parsingCode = '';
imports.add('strconv');
switch (typeName) {
case 'bool':
if (boolTarget) {
// we reuse the err var declared earlier when calling reader.NextPart()
parsingCode = `\t\t\t${boolTarget}, err = strconv.ParseBool(string(content))\n`;
parseErr = 'err';
}
else {
parsingCode = `\t\t\t${parseResults} := strconv.ParseBool(string(content))\n`;
}
break;
case 'float32':
case 'float64':
parsingCode = `\t\t\t${parseResults} := strconv.ParseFloat(string(content), ${helpers.getBitSizeForNumber(typeName)})\n`;
break;
case 'int8':
case 'int16':
case 'int32':
case 'int64':
parsingCode = `\t\t\t${parseResults} := strconv.ParseInt(string(content), 10, ${helpers.getBitSizeForNumber(typeName)})\n`;
break;
default:
throw new CodegenError('InternalError', `unhandled multipart parameter primitive type ${typeName}`);
}
parsingCode += `\t\t\tif ${parseErr} != nil {\n\t\t\t\treturn nil, ${parseErr}\n\t\t\t}\n`;
return parsingCode;
};
// returns true if type is a model or polymorphic model
const isModelType = function (type) {
return type.kind === 'model' || type.kind === 'polymorphicModel';
};
// returns the Go case clause that reads the part named caseValue and stores
// the (possibly converted) content in paramVar according to type.
const emitCase = function (caseValue, paramVar, type) {
let caseContent = `\t\tcase "${caseValue}":\n`;
caseContent += '\t\t\tcontent, err = io.ReadAll(part)\n';
caseContent += '\t\t\tif err != nil {\n\t\t\t\treturn nil, err\n\t\t\t}\n';
// when set, an assignment to paramVar is emitted at the end.
// branches that populate paramVar themselves leave it undefined.
let assignedValue;
if (isModelType(helpers.recursiveUnwrapMapSlice(type))) {
// models (including maps/slices of models) are unmarshalled from JSON
imports.add('encoding/json');
caseContent += `\t\t\tif err = json.Unmarshal(content, &${paramVar}); err != nil {\n\t\t\t\treturn nil, err\n\t\t\t}\n`;
}
else if (type.kind === 'readSeekCloser') {
imports.add('bytes');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
assignedValue = 'streaming.NopCloser(bytes.NewReader(content))';
}
else if (type.kind === 'constant') {
// the value that's cast to the constant's type.
// NOTE(review): stays undefined for underlying types not listed in the switch — confirm that can't happen
let from;
switch (type.type) {
case 'bool':
case 'float32':
case 'float64':
case 'int32':
case 'int64':
caseContent += parsePrimitiveType(type.type);
from = 'parsed';
break;
case 'string':
from = 'content';
break;
}
assignedValue = `${clientPkg}.${type.name}(${from})`;
}
else if (type.kind === 'scalar') {
switch (type.type) {
case 'bool':
imports.add('strconv');
// ParseBool happens in place, so no need to set assignedValue
caseContent += parsePrimitiveType(type.type, paramVar);
break;
case 'float32':
case 'float64':
case 'int8':
case 'int16':
case 'int32':
case 'int64':
// the parsed value is cast to the scalar's type
caseContent += parsePrimitiveType(type.type);
assignedValue = `${type.type}(parsed)`;
break;
default:
throw new CodegenError('InternalError', `unhandled multipart parameter primitive type ${type.type}`);
}
}
else if (type.kind === 'string') {
assignedValue = 'string(content)';
}
else if (helpers.recursiveUnwrapMapSlice(type).kind === 'multipartContent') {
imports.add('bytes');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
const bodyContent = 'streaming.NopCloser(bytes.NewReader(content))';
const contentType = 'part.Header.Get("Content-Type")';
const filename = 'part.FileName()';
if (type.kind === 'slice') {
// each matching part is appended to the slice
caseContent += `\t\t\t${paramVar} = append(${paramVar}, streaming.MultipartContent{\n`;
caseContent += `\t\t\t\tBody: ${bodyContent},\n`;
caseContent += `\t\t\t\tContentType: ${contentType},\n`;
caseContent += `\t\t\t\tFilename: ${filename},\n`;
caseContent += '\t\t\t})\n';
}
else {
caseContent += `\t\t\t${paramVar}.Body = ${bodyContent}\n`;
caseContent += `\t\t\t${paramVar}.ContentType = ${contentType}\n`;
caseContent += `\t\t\t${paramVar}.Filename = ${filename}\n`;
}
}
else if (type.kind === 'slice') {
// the only other slice element kind handled is readSeekCloser
if (type.elementType.kind === 'readSeekCloser') {
imports.add('bytes');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming');
assignedValue = `append(${paramVar}, streaming.NopCloser(bytes.NewReader(content)))`;
}
else {
throw new CodegenError('InternalError', `uhandled multipart parameter array element kind ${type.elementType.kind}`);
}
}
else {
throw new CodegenError('InternalError', `uhandled multipart parameter kind ${type.kind}`);
}
if (assignedValue) {
caseContent += `\t\t\t${paramVar} = ${assignedValue}\n`;
}
return caseContent;
};
for (const param of multipartBodyParams) {
if (isModelType(param.type)) {
// for model params, each field of the model gets its own case keyed by its serialized name
for (const field of param.type.fields) {
content += emitCase(field.serializedName, `${param.name}.${field.name}`, field.type);
}
}
else {
content += emitCase(param.name, param.name, param.type);
}
}
content += '\t\tdefault:\n\t\t\treturn nil, fmt.Errorf("unexpected part %s", fn)\n';
content += '\t\t}\n'; // end switch
content += '\t}\n'; // end for
}
else if (formBodyParams.length > 0) {
// declare a local var for each form param; they're populated in the loop emitted below
for (const param of formBodyParams) {
let pkgPrefix = '';
if (param.type.kind === 'constant') {
pkgPrefix = clientPkg + '.';
}
content += `\tvar ${param.name} ${pkgPrefix}${go.getTypeDeclaration(param.type)}\n`;
}
content += '\tif err := req.ParseForm(); err != nil {\n\t\treturn nil, &nonRetriableError{fmt.Errorf("failed parsing form data: %v", err)}\n\t}\n';
content += '\tfor key := range req.Form {\n';
content += '\t\tswitch key {\n';
for (const param of formBodyParams) {
content += `\t\tcase "${param.formDataName}":\n`;
// only constant and string form params are handled
let assignedValue;
switch (param.type.kind) {
case 'constant':
assignedValue = `${go.getTypeDeclaration(param.type, clientPkg)}(req.FormValue(key))`;
break;
case 'string':
assignedValue = 'req.FormValue(key)';
break;
default:
throw new CodegenError('InternalError', `uhandled form parameter kind ${param.type.kind}`);
}
content += `\t\t\t${param.name} = ${assignedValue}\n`;
}
content += '\t\t}\n'; // end switch
content += '\t}\n'; // end for
}
else if (partialBodyParams.length > 0) {
// construct the partial body params type and unmarshal it
content += '\ttype partialBodyParams struct {\n';
for (const partialBodyParam of partialBodyParams) {
content += `\t\t${capitalize(partialBodyParam.name)} ${helpers.star(partialBodyParam.byValue)}${go.getTypeDeclaration(partialBodyParam.type)} \`json:"${partialBodyParam.serializedName}"\`\n`;
}
content += '\t}\n';
// the format of the first partial body param selects the unmarshaller
content += `\tbody, err := server.UnmarshalRequestAs${partialBodyParams[0].format}[partialBodyParams](req)\n`;
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
}
// result.content is the parsing code; result.params holds the parsed values keyed by param name
const result = parseHeaderPathQueryParams(clientPkg, method, imports);
content += result.content;
// translate each partial body param to its field within the unmarshalled body
for (const partialBodyParam of partialBodyParams) {
result.params.set(partialBodyParam.name, `${helpers.star(partialBodyParam.byValue)}body.${capitalize(partialBodyParam.name)}`);
}
// the call to the fake, shared by both forms of assignment below
const apiCall = `:= ${receiverName}.srv.${fixUpMethodName(method)}(${populateApiParams(clientPkg, method, result.params, imports)})`;
if (method.kind === 'pageableMethod') {
// the fake for a pager returns a single value, so there's no error responder to check
content += `resp ${apiCall}\n`;
return content;
}
content += `\trespr, errRespr ${apiCall}\n`;
content += '\tif respErr := server.GetError(errRespr, req); respErr != nil {\n';
content += '\t\treturn nil, respErr\n\t}\n';
return content;
}
/**
 * Returns the HTTP status codes that indicate success for the specified method.
 *
 * For LROs the list always includes 200, and also 204 when the method
 * doesn't return a result. The method's own array is never modified.
 *
 * @param method the method for which to return the status codes
 * @returns a new array containing the status codes
 */
function getMethodStatusCodes(method) {
  // NOTE: work on a copy so the original array is left untouched!
  const codes = Array.from(method.httpStatusCodes);
  const isLRO = method.kind === 'lroMethod' || method.kind === 'lroPageableMethod';
  if (!isLRO) {
    return codes;
  }
  // pollers always include 200 as an acceptable status code so we emulate that here
  if (!codes.includes(200)) {
    codes.unshift(200);
  }
  // also include 204 if the LRO doesn't return a body
  const hasResult = Boolean(method.returns.result);
  if (!hasResult && !codes.includes(204)) {
    codes.push(204);
  }
  return codes;
}
/**
 * Emits the body of the Go dispatch method for a long-running operation.
 *
 * The emitted code looks up the poller responder tracked for the request.
 * If there isn't one it calls the fake and starts tracking the result. It
 * then advances the poller, validates the status code, and stops tracking
 * once the poller has nothing more to return.
 *
 * @param clientPkg the name of the client package
 * @param receiverName the name of the receiver in the emitted Go method
 * @param method the LRO method for which to emit the body
 * @param imports the import manager for the file being generated
 * @returns the Go source text
 */
function dispatchForLROBody(clientPkg, receiverName, method, imports) {
  const localVar = uncapitalize(fixUpMethodName(method));
  // the tracker field on the transport has the same name as the local var
  const tracker = `${receiverName}.${localVar}`;
  const callFake = dispatchForOperationBody(clientPkg, receiverName, method, imports);
  const okCodes = helpers.formatStatusCodes(getMethodStatusCodes(method));
  return [
    `\t${localVar} := ${tracker}.get(req)\n`,
    `\tif ${localVar} == nil {\n`,
    callFake,
    `\t\t${localVar} = &respr\n`,
    `\t\t${tracker}.add(req, ${localVar})\n`,
    '\t}\n\n',
    `\tresp, err := server.PollerResponderNext(${localVar}, req)\n`,
    '\tif err != nil {\n\t\treturn nil, err\n\t}\n\n',
    `\tif !contains([]int{${okCodes}}, resp.StatusCode) {\n`,
    `\t\t${tracker}.remove(req)\n`,
    `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${okCodes}", resp.StatusCode)}\n\t}\n`,
    `\tif !server.PollerResponderMore(${localVar}) {\n`,
    `\t\t${tracker}.remove(req)\n\t}\n\n`,
    '\treturn resp, nil\n',
  ].join('');
}
/**
 * Emits the body of the Go dispatch method for a pageable operation.
 *
 * The emitted code looks up the pager responder tracked for the request.
 * If there isn't one it calls the fake, starts tracking the result, and
 * (when the method has a next link name) injects next links into the pages.
 * It then advances the pager, validates the status code, and stops tracking
 * once the pager has nothing more to return.
 *
 * Adds the azcore/to import when the method has a next link name.
 *
 * @param clientPkg the name of the client package
 * @param receiverName the name of the receiver in the emitted Go method
 * @param method the pageable method for which to emit the body
 * @param imports the import manager for the file being generated
 * @returns the Go source text
 */
function dispatchForPagerBody(clientPkg, receiverName, method, imports) {
  const localVar = uncapitalize(fixUpMethodName(method));
  // the tracker field on the transport has the same name as the local var
  const tracker = `${receiverName}.${localVar}`;
  const parts = [
    `\t${localVar} := ${tracker}.get(req)\n`,
    `\tif ${localVar} == nil {\n`,
    dispatchForOperationBody(clientPkg, receiverName, method, imports),
    `\t\t${localVar} = &resp\n`,
    `\t\t${tracker}.add(req, ${localVar})\n`,
  ];
  if (method.nextLinkName) {
    imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
    parts.push(
      `\t\tserver.PagerResponderInjectNextLinks(${localVar}, req, func(page *${clientPkg}.${method.returns.name}, createLink func() string) {\n`,
      `\t\t\tpage.${method.nextLinkName} = to.Ptr(createLink())\n`,
      '\t\t})\n',
    );
  }
  const okCodes = helpers.formatStatusCodes(method.httpStatusCodes);
  parts.push(
    '\t}\n', // end if
    `\tresp, err := server.PagerResponderNext(${localVar}, req)\n`,
    '\tif err != nil {\n\t\treturn nil, err\n\t}\n',
    `\tif !contains([]int{${okCodes}}, resp.StatusCode) {\n`,
    `\t\t${tracker}.remove(req)\n`,
    `\t\treturn nil, &nonRetriableError{fmt.Errorf("unexpected status code %d. acceptable values are ${okCodes}", resp.StatusCode)}\n\t}\n`,
    `\tif !server.PagerResponderMore(${localVar}) {\n`,
    `\t\t${tracker}.remove(req)\n\t}\n`,
    '\treturn resp, nil\n',
  );
  return parts.join('');
}
// converts a path segment name into a name usable for a regexp capture group.
//   name - the raw path segment name
// returns the name with every dash replaced by an underscore.
function sanitizeRegexpCaptureGroupName(name) {
    // dash '-' characters are not allowed so replace them with '_'.
    // replaceAll is required here: replace() with a string pattern only
    // substitutes the first occurrence, leaving later dashes in place.
    return name.replaceAll('-', '_');
}
// builds a regular expression pattern from the method's HTTP path in which
// each path parameter placeholder is replaced by a named capture group.
//   method     - supplies the templated path via method.httpPath, e.g.
//                "/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}"
//   pathParams - the path parameters whose "{pathSegment}" placeholders are replaced
// returns the resulting pattern string.
function createPathParamsRegex(method, pathParams) {
    // per RFC3986, these pchars also double as regex tokens and must be
    // escaped before any capture groups are spliced in: . $ * + ( )
    let pattern = method.httpPath.replace(/([.$*+()])/g, '\\$1');
    for (const { pathSegment, style } of pathParams) {
        // optional and flag params might be absent from the path, so their capture is optional
        const quantifier = (style === 'optional' || style === 'flag') ? '?' : '';
        const groupName = sanitizeRegexpCaptureGroupName(pathSegment);
        const capture = `(?P<${groupName}>[!#&$-;=?-\\[\\]_a-zA-Z0-9~%@]+)${quantifier}`;
        pattern = pattern.replace(`{${pathSegment}}`, capture);
    }
    return pattern;
}
// parses header/path/query params as required.
// returns the parsing code and the params that contain the parsed values.
function parseHeaderPathQueryParams(clientPkg, method, imports) {
let content = '';
const paramValues = new Map();
const createLocalVariableName = function (param, suffix) {
const paramName = `${uncapitalize(param.name)}${suffix}`;
paramValues.set(param.name, paramName);
return paramName;
};
const emitNumericConversion = function (src, type) {
imports.add('strconv');
let precision = '32';
if (type === 'float64' || type === 'int64') {
precision = '64';
}
let parseType = 'Int';
let base = '10, ';
if (type === 'float32' || type === 'float64') {
parseType = 'Float';
base = '';
}
return `strconv.Parse${parseType}(${src}, ${base}${precision})`;
};
// track the param groups that need to be instantiated/populated.
// we track the params separately as it might be a subset of ParameterGroup.params
const paramGroups = new Map();
for (const param of values(consolidateHostParams(method.parameters))) {
if (param.location === 'client' || go.isLiteralParameter(param.style)) {
// client params and parameter literals aren't passed to APIs
continue;
}
if (param.kind === 'resumeTokenParam') {
// skip the ResumeToken param as we don't send that back to the caller
continue;
}
// NOTE: param group check must happen before skipping body params.
// this is to handle the case where the body param is grouped/optional
if (param.group) {
let params = paramGroups.get(param.group);
if (!params) {
params = new Array();
paramGroups.set(param.group, params);
}
params.push(param);
}
switch (param.kind) {
case 'bodyParam':
case 'formBodyCollectionParam':
case 'formBodyScalarParam':
case 'multipartFormBodyParam':
case 'partialBodyParam':
// body params will be unmarshalled, no need for parsing.
continue;
}
// paramValue is initialized with the "raw" source value.
// e.g. getHeaderValue(...), qp.Get("foo") etc
// since path/query params need to be unescaped, the value
// of paramValue will be updated with the var name that
// contains the unescaped value.
let paramValue = getRawParamValue(param);
// path/query params might be escaped, so we need to unescape them first.
// must handle query collections first as it's a superset of query param.
if (param.kind === 'queryCollectionParam' && param.collectionFormat === 'multi') {
imports.add('net/url');
const escapedParam = createLocalVariableName(param, 'Escaped');
content += `\t${escapedParam} := ${paramValue}\n`;
let paramVar = createLocalVariableName(param, 'Unescaped');
if (param.type.elementType.kind === 'string') {
// by convention, if the value is in its "final form" (i.e. no parsing required)
// then its var is to have the "Param" suffix. the only case is string, everything
// else requires some amount of parsing/conversion.
paramVar = createLocalVariableName(param, 'Param');
}
content += `\t${paramVar} := make([]string, len(${escapedParam}))\n`;
content += `\tfor i, v := range ${escapedParam} {\n`;
content += '\t\tu, unescapeErr := url.QueryUnescape(v)\n';
content += '\t\tif unescapeErr != nil {\n\t\t\treturn nil, unescapeErr\n\t\t}\n';
content += `\t\t${paramVar}[i] = u\n\t}\n`;
paramValue = paramVar;
}
else if (go.isPathParameter(param) || go.isQueryParameter(param)) {
imports.add('net/url');
let where;
if (go.isPathParameter(param)) {
where = 'Path';
}
else {
where = 'Query';
}
let paramVar = createLocalVariableName(param, 'Unescaped');
if (go.isRequiredParameter(param.style) && param.type.kind === 'constant' && param.type.type === 'string') {
// for string-based enums, we perform the conversion as part of unescaping
requiredHelpers.parseWithCast = true;
paramVar = createLocalVariableName(param, 'Param');
content += `\t${paramVar}, err := parseWithCast(${paramValue}, func (v string) (${go.getTypeDeclaration(param.type, clientPkg)}, error) {\n`;
content += `\t\tp, unescapeErr := url.${where}Unescape(v)\n`;
content += '\t\tif unescapeErr != nil {\n\t\t\treturn "", unescapeErr\n\t\t}\n';
content += `\t\treturn ${go.getTypeDeclaration(param.type, clientPkg)}(p), nil\n\t})\n`;
}
else {
if (go.isRequiredParameter(param.style) &&
(param.type.kind === 'string' || (param.type.kind === 'slice' && param.type.elementType.kind === 'string'))) {
// by convention, if the value is in its "final form" (i.e. no parsing required)
// then its var is to have the "Param" suffix. the only case is string, everything
// else requires some amount of parsing/conversion.
paramVar = createLocalVariableName(param, 'Param');
}
content += `\t${paramVar}, err := url.${where}Unescape(${paramValue})\n`;
}
content += '\tif err != nil {\n\t\treturn nil, err\n\t}\n';
paramValue = paramVar;
}
// parse params as required
if (param.kind === 'headerCollectionParam' || param.kind === 'pathCollectionParam' || param.kind === 'queryCollectionParam') {
// any element type other than string will require some form of conversion/parsing
if (param.type.elementType.kind !== 'string') {
if (param.collectionFormat !== 'multi') {
requiredHelpers.splitHelper = true;
const elementsParam = createLocalVariableName(param, 'Elements');
content += `\t${elementsParam} := splitHelper(${paramValue}, "${helpers.getDelimiterForCollectionFormat(param.collectionFormat)}")\n`;
paramValue = elementsParam;
}
const paramVar = createLocalVariableName(param, 'Param');
let elementFormat;
switch (param.type.elementType.kind) {
case 'constant':
case 'scalar':
elementFormat = param.type.elementType.type;
break;
case 'encodedBytes':
elementFormat = param.type.elementType.encoding;
break;
case 'time':
elementFormat = param.type.elementType.format;
break;
default:
throw new CodegenEr