@autorest/go
Version:
AutoRest Go Generator
915 lines • 45 kB
JavaScript
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as go from '../../codemodel.go/src/index.js';
import { capitalize, comment } from '@azure-tools/codegen';
import { values } from '@azure-tools/linq';
import * as helpers from './helpers.js';
import { ImportManager } from './imports.js';
// Creates the content in models.go
export async function generateModels(codeModel) {
if (codeModel.models.length === 0) {
return {
models: '',
serDe: ''
};
}
// this list of packages to import
const modelImports = new ImportManager();
const serdeImports = new ImportManager();
let modelText = helpers.contentPreamble(codeModel);
// we do model generation first as it can add imports to the imports list
const modelDefs = generateModelDefs(modelImports, serdeImports, codeModel);
modelText += modelImports.text();
// structs
let needsJSONPopulate = false;
let needsJSONUnpopulate = false;
let needsJSONPopulateByteArray = false;
let needsJSONPopulateAny = false;
let needsJSONPopulateMultipart = false;
let serdeTextBody = '';
for (const modelDef of values(modelDefs)) {
modelText += modelDef.text();
modelDef.Methods.sort((a, b) => { return helpers.sortAscending(a.name, b.name); });
for (const method of values(modelDef.Methods)) {
if (method.desc.length > 0) {
modelText += `${comment(method.desc, '// ', undefined, helpers.commentLength)}\n`;
}
modelText += method.text;
}
modelDef.SerDe.methods.sort((a, b) => { return helpers.sortAscending(a.name, b.name); });
for (const method of values(modelDef.SerDe.methods)) {
if (method.desc.length > 0) {
serdeTextBody += `${comment(method.desc, '// ', undefined, helpers.commentLength)}\n`;
}
serdeTextBody += method.text;
}
if (modelDef.SerDe.needsJSONPopulate) {
needsJSONPopulate = true;
}
if (modelDef.SerDe.needsJSONUnpopulate) {
needsJSONUnpopulate = true;
}
if (modelDef.SerDe.needsJSONPopulateByteArray) {
needsJSONPopulateByteArray = true;
}
if (modelDef.SerDe.needsJSONPopulateAny) {
needsJSONPopulateAny = true;
}
if (modelDef.SerDe.needsJSONPopulateMultipart) {
needsJSONPopulateMultipart = true;
}
}
if (needsJSONPopulate) {
serdeImports.add('reflect');
serdeImports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
serdeTextBody += 'func populate(m map[string]any, k string, v any) {\n';
serdeTextBody += '\tif v == nil {\n';
serdeTextBody += '\t\treturn\n';
serdeTextBody += '\t} else if azcore.IsNullValue(v) {\n';
serdeTextBody += '\t\tm[k] = nil\n';
serdeTextBody += '\t} else if !reflect.ValueOf(v).IsNil() {\n';
serdeTextBody += '\t\tm[k] = v\n';
serdeTextBody += '\t}\n';
serdeTextBody += '}\n\n';
}
if (needsJSONPopulateAny) {
serdeImports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
serdeTextBody += 'func populateAny(m map[string]any, k string, v any) {\n';
serdeTextBody += '\tif v == nil {\n';
serdeTextBody += '\t\treturn\n';
serdeTextBody += '\t} else if azcore.IsNullValue(v) {\n';
serdeTextBody += '\t\tm[k] = nil\n';
serdeTextBody += '\t} else {\n';
serdeTextBody += '\t\tm[k] = v\n';
serdeTextBody += '\t}\n';
serdeTextBody += '}\n\n';
}
if (needsJSONPopulateByteArray) {
serdeImports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
serdeImports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
serdeTextBody += 'func populateByteArray[T any](m map[string]any, k string, b []T, convert func() any) {\n';
serdeTextBody += '\tif azcore.IsNullValue(b) {\n';
serdeTextBody += '\t\tm[k] = nil\n';
serdeTextBody += '\t} else if len(b) == 0 {\n';
serdeTextBody += '\t\treturn\n';
serdeTextBody += '\t} else {\n';
serdeTextBody += '\t\tm[k] = convert()\n';
serdeTextBody += '\t}\n';
serdeTextBody += '}\n\n';
}
if (needsJSONUnpopulate) {
serdeImports.add('fmt');
serdeTextBody += 'func unpopulate(data json.RawMessage, fn string, v any) error {\n';
serdeTextBody += '\tif data == nil || string(data) == "null" {\n';
serdeTextBody += '\t\treturn nil\n';
serdeTextBody += '\t}\n';
serdeTextBody += '\tif err := json.Unmarshal(data, v); err != nil {\n';
serdeTextBody += '\t\treturn fmt.Errorf("struct field %s: %v", fn, err)\n';
serdeTextBody += '\t}\n';
serdeTextBody += '\treturn nil\n';
serdeTextBody += '}\n\n';
}
if (needsJSONPopulateMultipart) {
serdeImports.add('encoding/json');
serdeTextBody += 'func populateMultipartJSON(m map[string]any, k string, v any) error {\n';
serdeTextBody += '\tdata, err := json.Marshal(v)\n';
serdeTextBody += '\tif err != nil {\n\t\treturn err\n\t}\n';
serdeTextBody += '\tm[k] = data\n\treturn nil\n';
serdeTextBody += '}\n\n';
}
let serdeText = '';
if (serdeTextBody.length > 0) {
serdeText = helpers.contentPreamble(codeModel);
serdeText += serdeImports.text();
serdeText += serdeTextBody;
}
return {
models: modelText,
serDe: serdeText
};
}
function generateModelDefs(modelImports, serdeImports, codeModel) {
const models = codeModel.models;
const modelDefs = new Array();
for (const model of models) {
for (const field of model.fields) {
const descriptionMods = new Array();
if (field.annotations.readOnly) {
descriptionMods.push('READ-ONLY');
}
else if (field.annotations.required && (field.type.kind !== 'literal' || model.usage === go.UsageFlags.Output)) {
descriptionMods.push('REQUIRED');
}
else if (field.type.kind === 'literal') {
if (!field.annotations.required) {
descriptionMods.push('FLAG');
}
descriptionMods.push('CONSTANT');
}
if (field.type.kind === 'literal' && model.usage !== go.UsageFlags.Output) {
// add a comment with the const value for const properties that are sent over the wire
if (field.docs.description) {
field.docs.description += '\n';
}
field.docs.description += `Field has constant value ${helpers.formatLiteralValue(field.type, false)}, any specified value is ignored.`;
}
if (field.docs.description) {
descriptionMods.push(field.docs.description);
}
else if (field.type.kind === 'rawJSON') {
// add a basic description if one isn't available
descriptionMods.push('The contents of this field are raw JSON.');
}
field.docs.description = descriptionMods.join('; ');
}
const serDeFormat = helpers.getSerDeFormat(model, codeModel);
const modelDef = new ModelDef(model.name, serDeFormat, model.fields, model.docs);
for (const field of values(modelDef.Fields)) {
modelImports.addImportForType(field.type);
}
if (model.kind === 'model' && serDeFormat === 'XML' && !model.annotations.omitSerDeMethods) {
serdeImports.add('encoding/xml');
let needsDateTimeMarshalling = false;
let byteArrayFormat = false;
for (const field of values(model.fields)) {
serdeImports.addImportForType(field.type);
if (field.type.kind === 'time') {
needsDateTimeMarshalling = true;
}
else if (field.type.kind === 'encodedBytes') {
byteArrayFormat = true;
}
}
// due to differences in XML marshallers/unmarshallers, we use different codegen than for JSON
if (needsDateTimeMarshalling || model.xml?.wrapper || needsXMLArrayMarshalling(model) || byteArrayFormat) {
generateXMLMarshaller(model, modelDef, serdeImports);
if (needsDateTimeMarshalling || byteArrayFormat) {
generateXMLUnmarshaller(model, modelDef, serdeImports);
}
}
else if (needsXMLDictionaryHelper(model)) {
generateXMLMarshaller(model, modelDef, serdeImports);
generateXMLUnmarshaller(model, modelDef, serdeImports);
}
modelDefs.push(modelDef);
continue;
}
if (model.kind === 'polymorphicModel') {
generateDiscriminatorMarkerMethod(model.interface, modelDef);
for (let parent = model.interface.parent; parent !== undefined; parent = parent.parent) {
generateDiscriminatorMarkerMethod(parent, modelDef);
}
}
if (model.annotations.multipartFormData) {
generateToMultipartForm(modelDef);
modelDef.SerDe.needsJSONPopulateMultipart = true;
}
else if (!model.annotations.omitSerDeMethods) {
generateJSONMarshaller(model, modelDef, serdeImports);
generateJSONUnmarshaller(model, modelDef, serdeImports, codeModel.options);
}
modelDefs.push(modelDef);
}
return modelDefs;
}
function needsXMLDictionaryHelper(modelType) {
for (const field of values(modelType.fields)) {
// additional properties uses an internal wrapper type with its own serde impl
if (field.type.kind === 'map' && !field.annotations.isAdditionalProperties) {
return true;
}
}
return false;
}
function needsXMLArrayMarshalling(modelType) {
for (const prop of values(modelType.fields)) {
if (prop.type.kind === 'slice') {
return true;
}
}
return false;
}
// generates discriminator marker method
function generateDiscriminatorMarkerMethod(type, modelDef) {
const typeName = type.rootType.name;
const receiver = modelDef.receiverName();
const interfaceMethod = `Get${typeName}`;
let method = `func (${receiver} *${modelDef.Name}) ${interfaceMethod}() *${typeName} {`;
if (type.rootType.name === modelDef.Name) {
// the marker method is on the discriminator itself, so just return the receiver
method += ` return ${receiver} }\n\n`;
}
else {
// the marker method is on a child type, so return an instance of the parent
// type by copying the parent values into a new instance.
method += `\n\treturn &${type.rootType.name}{\n`;
for (const field of values(type.rootType.fields)) {
method += `\t\t${field.name}: ${modelDef.receiverName()}.${field.name},\n`;
}
method += '\t}\n}\n\n';
}
modelDef.Methods.push({ name: interfaceMethod, desc: `${interfaceMethod} implements the ${type.name} interface for type ${modelDef.Name}.`, text: method });
}
function generateToMultipartForm(modelDef) {
const receiver = modelDef.receiverName();
let method = `func (${receiver} ${modelDef.Name}) toMultipartFormData() (map[string]any, error) {\n`;
method += '\tobjectMap := make(map[string]any)\n';
for (const field of modelDef.Fields) {
const fieldType = helpers.recursiveUnwrapMapSlice(field.type);
let setField;
let star = '';
if (!field.byValue) {
star = '*';
}
if (fieldType.kind === 'model' && !fieldType.annotations.multipartFormData) {
setField = `\tif err := populateMultipartJSON(objectMap, "${field.serializedName}", ${star}${receiver}.${field.name}); err != nil {\n\t\treturn nil, err\n\t}\n`;
}
else {
setField = `\tobjectMap["${field.serializedName}"] = ${star}${receiver}.${field.name}\n`;
}
if (!field.byValue) {
setField = `\tif ${receiver}.${field.name} != nil {\n\t\t${setField}\t}\n`;
}
method += setField;
}
method += '\treturn objectMap, nil\n}\n\n';
modelDef.SerDe.methods.push({ name: 'toMultipartFormData', desc: `toMultipartFormData converts ${modelDef.Name} to multipart/form data.`, text: method });
}
function generateJSONMarshaller(modelType, modelDef, imports) {
if (modelType.kind === 'model' && modelType.fields.length === 0) {
// non-discriminated types without content don't need a custom marshaller.
// there is a case in network where child is allOf base and child has no properties.
return;
}
imports.add('encoding/json');
const typeName = modelDef.Name;
const receiver = modelDef.receiverName();
let marshaller = `func (${receiver} ${typeName}) MarshalJSON() ([]byte, error) {\n`;
marshaller += '\tobjectMap := make(map[string]any)\n';
marshaller += generateJSONMarshallerBody(modelType, modelDef, receiver, imports);
marshaller += '\treturn json.Marshal(objectMap)\n';
marshaller += '}\n\n';
modelDef.SerDe.methods.push({ name: 'MarshalJSON', desc: `MarshalJSON implements the json.Marshaller interface for type ${typeName}.`, text: marshaller });
}
function generateJSONMarshallerBody(modelType, modelDef, receiver, imports) {
let marshaller = '';
let addlProps;
for (const field of values(modelType.fields)) {
if (field.type.kind === 'map' && field.annotations.isAdditionalProperties) {
addlProps = field.type;
continue;
}
if (field.annotations.isDiscriminator) {
if (field.defaultValue) {
marshaller += `\tobjectMap["${field.serializedName}"] = ${helpers.formatLiteralValue(field.defaultValue, true)}\n`;
}
else {
// if there's no discriminator value (e.g. Fish in test server), use the field's value.
// this will enable support for custom types that aren't (yet) described in the swagger.
marshaller += `\tobjectMap["${field.serializedName}"] = ${receiver}.${field.name}\n`;
}
}
else if (field.type.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
marshaller += `\tpopulateByteArray(objectMap, "${field.serializedName}", ${receiver}.${field.name}, func() any {\n`;
marshaller += `\t\treturn runtime.EncodeByteArray(${receiver}.${field.name}, runtime.Base64${field.type.encoding}Format)\n\t})\n`;
modelDef.SerDe.needsJSONPopulateByteArray = true;
}
else if (field.type.kind === 'slice' && field.type.elementType.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
marshaller += `\tpopulateByteArray(objectMap, "${field.serializedName}", ${receiver}.${field.name}, func() any {\n`;
marshaller += `\t\tencodedValue := make([]string, len(${receiver}.${field.name}))\n`;
marshaller += `\t\tfor i := 0; i < len(${receiver}.${field.name}); i++ {\n`;
marshaller += `\t\t\tencodedValue[i] = runtime.EncodeByteArray(${receiver}.${field.name}[i], runtime.Base64${field.type.elementType.encoding}Format)\n\t\t}\n`;
marshaller += '\t\treturn encodedValue\n\t})\n';
modelDef.SerDe.needsJSONPopulateByteArray = true;
}
else if (field.type.kind === 'slice' && field.type.elementType.kind === 'time') {
const source = `${receiver}.${field.name}`;
let elementPtr = '*';
if (field.type.elementTypeByValue) {
elementPtr = '';
}
marshaller += `\taux := make([]${elementPtr}${field.type.elementType.format}, len(${source}), len(${source}))\n`;
marshaller += `\tfor i := 0; i < len(${source}); i++ {\n`;
marshaller += `\t\taux[i] = (${elementPtr}${field.type.elementType.format})(${source}[i])\n`;
marshaller += '\t}\n';
marshaller += `\tpopulate(objectMap, "${field.serializedName}", aux)\n`;
modelDef.SerDe.needsJSONPopulate = true;
}
else if (field.type.kind === 'literal') {
const setter = `objectMap["${field.serializedName}"] = ${helpers.formatLiteralValue(field.type, true)}`;
if (!field.annotations.required) {
marshaller += `\tif ${receiver}.${field.name} != nil {\n\t\t${setter}\n\t}\n`;
}
else {
marshaller += `\t${setter}\n`;
}
}
else if (field.type.kind === 'rawJSON') {
marshaller += `\tpopulate(objectMap, "${field.serializedName}", json.RawMessage(${receiver}.${field.name}))\n`;
modelDef.SerDe.needsJSONPopulate = true;
}
else {
if (field.defaultValue) {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
marshaller += `\tif ${receiver}.${field.name} == nil {\n\t\t${receiver}.${field.name} = to.Ptr(${helpers.formatLiteralValue(field.defaultValue, true)})\n\t}\n`;
}
let populate = 'populate';
if (field.type.kind === 'time') {
populate += capitalize(field.type.format);
modelDef.SerDe.needsJSONPopulate = true;
}
else if (field.type.kind === 'any') {
populate += 'Any';
modelDef.SerDe.needsJSONPopulateAny = true;
}
else {
modelDef.SerDe.needsJSONPopulate = true;
}
if (field.type.kind === 'scalar' && (field.type.type.startsWith('uint') || field.type.type.startsWith('int')) && field.type.encodeAsString) {
// TODO: need to handle map and slice type with underlying int as string type
imports.add('strconv');
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/to');
if (field.type.type.startsWith('uint') && field.type.type !== 'uint64' || field.type.type.startsWith('int') && field.type.type !== 'int64') {
marshaller += `\t${populate}(objectMap, "${field.serializedName}", to.Ptr(strconv.${field.type.type.startsWith('int') ? 'FormatInt' : 'FormatUint'}(${field.type.type.startsWith('int') ? 'int64' : 'uint64'}(*${receiver}.${field.name}), 10)))\n`;
}
else {
marshaller += `\t${populate}(objectMap, "${field.serializedName}", to.Ptr(strconv.${field.type.type.startsWith('int') ? 'FormatInt' : 'FormatUint'}(*${receiver}.${field.name}, 10)))\n`;
}
}
else {
marshaller += `\t${populate}(objectMap, "${field.serializedName}", ${receiver}.${field.name})\n`;
}
}
}
if (addlProps) {
marshaller += `\tif ${receiver}.AdditionalProperties != nil {\n`;
marshaller += `\t\tfor key, val := range ${receiver}.AdditionalProperties {\n`;
let assignment = 'val';
if (addlProps.valueType.kind === 'time') {
assignment = `(*${addlProps.valueType.format})(val)`;
}
marshaller += `\t\t\tobjectMap[key] = ${assignment}\n`;
marshaller += '\t\t}\n';
marshaller += '\t}\n';
}
return marshaller;
}
function generateJSONUnmarshaller(modelType, modelDef, imports, options) {
// there's a corner-case where a derived type might not add any new fields (Cookiecuttershark).
// in this case skip adding the unmarshaller as it's not necessary and doesn't compile.
if (modelDef.Fields.length === 0) {
return;
}
imports.add('encoding/json');
imports.add('fmt');
const typeName = modelDef.Name;
const receiver = modelDef.receiverName();
let unmarshaller = `func (${receiver} *${typeName}) UnmarshalJSON(data []byte) error {\n`;
unmarshaller += '\tvar rawMsg map[string]json.RawMessage\n';
unmarshaller += '\tif err := json.Unmarshal(data, &rawMsg); err != nil {\n';
unmarshaller += `\t\treturn fmt.Errorf("unmarshalling type %T: %v", ${receiver}, err)\n`;
unmarshaller += '\t}\n';
unmarshaller += generateJSONUnmarshallerBody(modelType, modelDef, receiver, imports, options);
unmarshaller += '}\n\n';
modelDef.SerDe.methods.push({ name: 'UnmarshalJSON', desc: `UnmarshalJSON implements the json.Unmarshaller interface for type ${typeName}.`, text: unmarshaller });
}
function generateJSONUnmarshallerBody(modelType, modelDef, receiver, imports, options) {
// we almost always need to have an error check when unmarshaling the values.
// however, fields that are raw JSON don't require any unmarshaling. so, if all
// of the fields in a type are raw JSON, then the error check isn't necessary
// and can be elided (the linter complains about it otherwise).
let needsErrCheck = false;
const emitAddlProps = function (tab, addlProps) {
let addlPropsText = `${tab}\t\tif ${receiver}.AdditionalProperties == nil {\n`;
let ref = '';
if (!addlProps.valueTypeByValue) {
ref = '&';
}
addlPropsText += `${tab}\t\t\t${receiver}.AdditionalProperties = ${go.getTypeDeclaration(addlProps)}{}\n`;
addlPropsText += `${tab}\t\t}\n`;
addlPropsText += `${tab}\t\tif val != nil {\n`;
let auxType = go.getTypeDeclaration(addlProps.valueType);
let assignment = `${ref}aux`;
if (addlProps.valueType.kind === 'time') {
imports.add('time');
auxType = addlProps.valueType.format;
assignment = `(*time.Time)(${assignment})`;
}
addlPropsText += `${tab}\t\t\tvar aux ${auxType}\n`;
addlPropsText += `${tab}\t\t\terr = json.Unmarshal(val, &aux)\n`;
addlPropsText += `${tab}\t\t\t${receiver}.AdditionalProperties[key] = ${assignment}\n`;
addlPropsText += `${tab}\t\t}\n`;
addlPropsText += `${tab}\t\tdelete(rawMsg, key)\n`;
needsErrCheck = true;
return addlPropsText;
};
const emitSwitchCase = function () {
let unmarshalBody = '';
let addlProps;
unmarshalBody += '\t\tswitch key {\n';
for (const field of values(modelType.fields)) {
if (field.type.kind === 'map' && field.annotations.isAdditionalProperties) {
addlProps = field.type;
continue;
}
unmarshalBody += `\t\tcase "${field.serializedName}":\n`;
if (hasDiscriminatorInterface(field.type)) {
unmarshalBody += generateDiscriminatorUnmarshaller(field, receiver);
needsErrCheck = true;
}
else if (field.type.kind === 'time') {
unmarshalBody += `\t\t\t\terr = unpopulate${capitalize(field.type.format)}(val, "${field.name}", &${receiver}.${field.name})\n`;
modelDef.SerDe.needsJSONUnpopulate = true;
needsErrCheck = true;
}
else if (field.type.kind === 'slice' && field.type.elementType.kind === 'time') {
imports.add('time');
let elementPtr = '*';
if (field.type.elementTypeByValue) {
elementPtr = '';
}
unmarshalBody += `\t\t\tvar aux []${elementPtr}${field.type.elementType.format}\n`;
unmarshalBody += `\t\t\terr = unpopulate(val, "${field.name}", &aux)\n`;
unmarshalBody += '\t\t\tfor _, au := range aux {\n';
unmarshalBody += `\t\t\t\t${receiver}.${field.name} = append(${receiver}.${field.name}, (${elementPtr}time.Time)(au))\n`;
unmarshalBody += '\t\t\t}\n';
modelDef.SerDe.needsJSONUnpopulate = true;
needsErrCheck = true;
}
else if (field.type.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
unmarshalBody += '\t\tif val != nil && string(val) != "null" {\n';
unmarshalBody += `\t\t\t\terr = runtime.DecodeByteArray(string(val), &${receiver}.${field.name}, runtime.Base64${field.type.encoding}Format)\n\t\t}\n`;
needsErrCheck = true;
}
else if (field.type.kind === 'slice' && field.type.elementType.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
unmarshalBody += '\t\t\tvar encodedValue []string\n';
unmarshalBody += `\t\t\terr = unpopulate(val, "${field.name}", &encodedValue)\n`;
unmarshalBody += '\t\t\tif err == nil && len(encodedValue) > 0 {\n';
unmarshalBody += `\t\t\t\t${receiver}.${field.name} = make([][]byte, len(encodedValue))\n`;
unmarshalBody += '\t\t\t\tfor i := 0; i < len(encodedValue) && err == nil; i++ {\n';
unmarshalBody += `\t\t\t\t\terr = runtime.DecodeByteArray(encodedValue[i], &${receiver}.${field.name}[i], runtime.Base64${field.type.elementType.encoding}Format)\n`;
unmarshalBody += '\t\t\t\t}\n\t\t\t}\n';
modelDef.SerDe.needsJSONUnpopulate = true;
needsErrCheck = true;
}
else if (field.type.kind === 'rawJSON') {
unmarshalBody += '\t\t\tif string(val) != "null" {\n';
unmarshalBody += `\t\t\t\t${receiver}.${field.name} = val\n\t\t\t}\n`;
}
else if (field.type.kind === 'scalar' && (field.type.type.startsWith('uint') || field.type.type.startsWith('int')) && field.type.encodeAsString) {
// TODO: need to handle map and slice type with underlying int as string type
imports.add('strconv');
unmarshalBody += `\t\t\t\tvar aux string\n`;
unmarshalBody += `\t\t\t\terr = unpopulate(val, "${field.name}", &aux)\n`;
unmarshalBody += `\t\t\t\tif err == nil {\n`;
unmarshalBody += `\t\t\t\t\tvar v ${field.type.type.startsWith('int') ? 'int64' : 'uint64'}\n`;
unmarshalBody += `\t\t\t\t\tv, err = strconv.${field.type.type.startsWith('int') ? 'ParseInt' : 'ParseUint'}(aux, 10, 0)\n`;
unmarshalBody += `\t\t\t\t\tif err == nil {\n`;
if (field.type.type.startsWith('uint') && field.type.type !== 'uint64' || field.type.type.startsWith('int') && field.type.type !== 'int64') {
unmarshalBody += `\t\t\t\t\t\t${receiver}.${field.name} = to.Ptr(${field.type.type}(v))\n`;
}
else {
unmarshalBody += `\t\t\t\t\t\t${receiver}.${field.name} = to.Ptr(v)\n`;
}
unmarshalBody += '\t\t\t\t\t}\n';
unmarshalBody += '\t\t\t\t}\n';
modelDef.SerDe.needsJSONUnpopulate = true;
needsErrCheck = true;
}
else {
unmarshalBody += `\t\t\t\terr = unpopulate(val, "${field.name}", &${receiver}.${field.name})\n`;
modelDef.SerDe.needsJSONUnpopulate = true;
needsErrCheck = true;
}
unmarshalBody += '\t\t\tdelete(rawMsg, key)\n';
}
if (addlProps) {
unmarshalBody += '\t\tdefault:\n';
unmarshalBody += emitAddlProps('\t', addlProps);
}
else if (options.disallowUnknownFields) {
unmarshalBody += '\t\tdefault:\n';
unmarshalBody += `\t\t\terr = fmt.Errorf("unmarshalling type %T, unknown field %q", ${receiver}, key)\n`;
needsErrCheck = true;
}
unmarshalBody += '\t\t}\n';
return unmarshalBody;
};
let unmarshalBody = '\tfor key, val := range rawMsg {\n';
// emitSwitchCase sets needsErrCheck so we must call it first
const switchCaseBody = emitSwitchCase();
if (needsErrCheck) {
unmarshalBody += '\t\tvar err error\n';
}
unmarshalBody += switchCaseBody;
if (needsErrCheck) {
unmarshalBody += '\t\tif err != nil {\n';
unmarshalBody += `\t\t\treturn fmt.Errorf("unmarshalling type %T: %v", ${receiver}, err)\n`;
unmarshalBody += '\t\t}\n';
}
unmarshalBody += '\t}\n'; // end for key, val := range rawMsg
unmarshalBody += '\treturn nil\n';
return unmarshalBody;
}
// returns true if item has a discriminator interface.
// recursively called for arrays and dictionaries.
function hasDiscriminatorInterface(item) {
switch (item.kind) {
case 'interface':
return true;
case 'map':
return hasDiscriminatorInterface(item.valueType);
case 'slice':
return hasDiscriminatorInterface(item.elementType);
default:
return false;
}
}
// returns the text for unmarshalling a discriminated type
function generateDiscriminatorUnmarshaller(field, receiver) {
const startingIndentation = '\t\t\t';
const propertyName = field.name;
// these are the simple, non-nested cases (e.g. IterfaceType, []InterfaceType, map[string]InterfaceType)
if (field.type.kind === 'interface') {
return `${startingIndentation}${receiver}.${propertyName}, err = unmarshal${field.type.name}(val)\n`;
}
else if (field.type.kind === 'slice' && field.type.elementType.kind === 'interface') {
return `${startingIndentation}${receiver}.${propertyName}, err = unmarshal${field.type.elementType.name}Array(val)\n`;
}
else if (field.type.kind === 'map' && field.type.valueType.kind === 'interface') {
return `${startingIndentation}${receiver}.${propertyName}, err = unmarshal${field.type.valueType.name}Map(val)\n`;
}
// nested case (e.g. [][]InterfaceType, map[string]map[string]InterfaceType etc)
// first, unmarshal the raw data
const rawTargetVar = `${field.serializedName}Raw`;
let text = `${startingIndentation}var ${rawTargetVar} ${recursiveGetDiscriminatorTypeName(field.type, true)}\n`;
text += `${startingIndentation}if err = json.Unmarshal(val, &${rawTargetVar}); err != nil {\n`;
text += `${startingIndentation}\treturn err\n${startingIndentation}}\n`;
// create a local instantiation of the final type
const finalTargetVar = field.serializedName;
let finalTargetCtor = recursiveGetDiscriminatorTypeName(field.type, false);
if (field.type.kind === 'slice') {
finalTargetCtor = `make(${finalTargetCtor}, len(${rawTargetVar}))`;
}
else {
// must be a dictionary
finalTargetCtor = `${finalTargetCtor}{}`;
}
text += `${startingIndentation}${finalTargetVar} := ${finalTargetCtor}\n`;
// now populate the final type
text += recursivePopulateDiscriminator(field.type, receiver, rawTargetVar, finalTargetVar, startingIndentation, 1);
// finally, assign the final target to the property
text += `${startingIndentation}${receiver}.${propertyName} = ${finalTargetVar}\n`;
return text;
}
// constructs the type name for a nested discriminated type
// raw e.g. map[string]json.RawMessage, []json.RawMessage etc
// !raw e.g. map[string]map[string]InterfaceType, [][]InterfaceType etc
function recursiveGetDiscriminatorTypeName(item, raw) {
// when raw is true, stop recursing at the level before the leaf schema
if (item.kind === 'slice') {
if (!raw || item.elementType.kind !== 'interface') {
return `[]${recursiveGetDiscriminatorTypeName(item.elementType, raw)}`;
}
}
else if (item.kind === 'map') {
if (!raw || item.valueType.kind !== 'interface') {
return `map[string]${recursiveGetDiscriminatorTypeName(item.valueType, raw)}`;
}
}
if (raw) {
return 'json.RawMessage';
}
return go.getTypeDeclaration(item);
}
// recursively constructs the text to populate a nested discriminator
function recursivePopulateDiscriminator(item, receiver, rawSrc, dest, indent, nesting) {
let text = '';
let interfaceName = '';
let targetType = '';
if (item.kind === 'slice') {
if (item.elementType.kind !== 'interface') {
if (nesting > 1) {
// at nestling level 1, the destination var was already created in generateDiscriminatorUnmarshaller()
text += `${indent}${dest} = make(${recursiveGetDiscriminatorTypeName(item, false)}, len(${rawSrc}))\n`;
}
text += `${indent}for i${nesting} := range ${rawSrc} {\n`;
rawSrc = `${rawSrc}[i${nesting}]`; // source becomes each element in the source slice
dest = `${dest}[i${nesting}]`; // update destination to each element in the destination slice
text += recursivePopulateDiscriminator(item.elementType, receiver, rawSrc, dest, indent + '\t', nesting + 1);
text += `${indent}}\n`;
return text;
}
// we're at leaf node - 1, so get the interface from the element's type
interfaceName = go.getTypeDeclaration(item.elementType);
targetType = 'Array';
}
else if (item.kind === 'map') {
if (item.valueType.kind !== 'interface') {
if (nesting > 1) {
// at nestling level 1, the destination var was already created in generateDiscriminatorUnmarshaller()
text += `${indent}${dest} = ${recursiveGetDiscriminatorTypeName(item, false)}{}\n`;
}
text += `${indent}for k${nesting}, v${nesting} := range ${rawSrc} {\n`;
rawSrc = `v${nesting}`; // source becomes the current value in the source map
dest = `${dest}[k${nesting}]`; // update destination to the destination map's value for the current key
text += recursivePopulateDiscriminator(item.valueType, receiver, rawSrc, dest, indent + '\t', nesting + 1);
text += `${indent}}\n`;
return text;
}
// we're at leaf node - 1, so get the interface from the element's type
interfaceName = go.getTypeDeclaration(item.valueType);
targetType = 'Map';
}
text += `${indent}${dest}, err = unmarshal${interfaceName}${targetType}(${rawSrc})\n`;
text += `${indent}if err != nil {\n${indent}\treturn fmt.Errorf("unmarshalling type %T: %v", ${receiver}, err)\n${indent}}\n`;
return text;
}
function generateXMLMarshaller(modelType, modelDef, imports) {
// only needed for types with time.Time, maps, or where the XML name doesn't match the type name
const receiver = modelDef.receiverName();
const desc = `MarshalXML implements the xml.Marshaller interface for type ${modelDef.Name}.`;
let text = `func (${receiver} ${modelDef.Name}) MarshalXML(enc *xml.Encoder, start xml.StartElement) error {\n`;
if (modelType.xml?.wrapper) {
text += `\tstart.Name.Local = "${modelType.xml.wrapper}"\n`;
}
text += generateAliasType(modelType, receiver, true);
for (const field of values(modelDef.Fields)) {
if (field.type.kind === 'slice') {
text += `\tif ${receiver}.${field.name} != nil {\n`;
text += `\t\taux.${field.name} = &${receiver}.${field.name}\n`;
text += '\t}\n';
}
else if (field.annotations.isAdditionalProperties || field.type.kind === 'map') {
text += `\taux.${field.name} = (additionalProperties)(${receiver}.${field.name})\n`;
}
else if (field.type.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
text += `\tif ${receiver}.${field.name} != nil {\n`;
text += `\t\tencoded${field.name} := runtime.EncodeByteArray(${receiver}.${field.name}, runtime.Base64${field.type.encoding}Format)\n`;
text += `\t\taux.${field.name} = &encoded${field.name}\n`;
text += '\t}\n';
}
}
text += '\treturn enc.EncodeElement(aux, start)\n';
text += '}\n\n';
modelDef.SerDe.methods.push({ name: 'MarshalXML', desc: desc, text: text });
}
function generateXMLUnmarshaller(modelType, modelDef, imports) {
// non-polymorphic case, must be something with time.Time
const receiver = modelDef.receiverName();
const desc = `UnmarshalXML implements the xml.Unmarshaller interface for type ${modelDef.Name}.`;
let text = `func (${receiver} *${modelDef.Name}) UnmarshalXML(dec *xml.Decoder, start xml.StartElement) error {\n`;
text += generateAliasType(modelType, receiver, false);
text += '\tif err := dec.DecodeElement(aux, &start); err != nil {\n';
text += '\t\treturn err\n';
text += '\t}\n';
for (const field of values(modelDef.Fields)) {
if (field.type.kind === 'time') {
text += `\tif aux.${field.name} != nil && !(*time.Time)(aux.${field.name}).IsZero() {\n`;
text += `\t\t${receiver}.${field.name} = (*time.Time)(aux.${field.name})\n\t}\n`;
}
else if (field.annotations.isAdditionalProperties || field.type.kind === 'map') {
text += `\t${receiver}.${field.name} = (map[string]*string)(aux.${field.name})\n`;
}
else if (field.type.kind === 'encodedBytes') {
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime');
text += `\tif aux.${field.name} != nil {\n`;
text += `\t\tif err := runtime.DecodeByteArray(*aux.${field.name}, &${receiver}.${field.name}, runtime.Base64${field.type.encoding}Format); err != nil {\n`;
text += '\t\t\treturn err\n';
text += '\t\t}\n';
text += '\t}\n';
}
}
text += '\treturn nil\n';
text += '}\n\n';
modelDef.SerDe.methods.push({ name: 'UnmarshalXML', desc: desc, text: text });
}
// generates an alias type used by custom XML marshaller/unmarshaller
function generateAliasType(modelType, receiver, forMarshal) {
let text = `\ttype alias ${modelType.name}\n`;
text += '\taux := &struct {\n';
text += '\t\t*alias\n';
for (const field of values(modelType.fields)) {
const sn = getXMLSerialization(field, false);
if (field.type.kind === 'time') {
text += `\t\t${field.name} *${field.type.format} \`xml:"${sn}"\`\n`;
}
else if (field.annotations.isAdditionalProperties || field.type.kind === 'map') {
text += `\t\t${field.name} additionalProperties \`xml:"${sn}"\`\n`;
}
else if (field.type.kind === 'slice') {
text += `\t\t${field.name} *${go.getTypeDeclaration(field.type)} \`xml:"${sn}"\`\n`;
}
else if (field.type.kind === 'encodedBytes') {
text += `\t\t${field.name} *string \`xml:"${sn}"\`\n`;
}
}
text += '\t}{\n';
let rec = receiver;
if (forMarshal) {
rec = '&' + rec;
}
text += `\t\talias: (*alias)(${rec}),\n`;
if (forMarshal) {
// emit code to initialize time fields
for (const field of modelType.fields) {
if (field.type.kind !== 'time') {
continue;
}
text += `\t\t${field.name}: (*${field.type.format})(${receiver}.${field.name}),\n`;
}
}
text += '\t}\n';
return text;
}
class SerDeInfo {
methods;
needsJSONPopulate;
needsJSONUnpopulate;
needsJSONPopulateByteArray;
needsJSONPopulateAny;
needsJSONPopulateMultipart;
constructor() {
this.methods = new Array();
this.needsJSONPopulate = false;
this.needsJSONUnpopulate = false;
this.needsJSONPopulateByteArray = false;
this.needsJSONPopulateAny = false;
this.needsJSONPopulateMultipart = false;
}
}
// represents model definition as a Go struct
class ModelDef {
Name;
Format;
Docs;
Fields;
SerDe;
Methods;
constructor(name, format, fields, docs) {
this.Name = name;
this.Format = format;
this.Docs = docs;
this.Fields = fields;
this.SerDe = new SerDeInfo();
this.Methods = new Array();
}
text() {
let text = helpers.formatDocComment(this.Docs);
text += `type ${this.Name} struct {\n`;
// group fields by required/optional/read-only in that order
this.Fields?.sort((lhs, rhs) => {
if ((lhs.annotations.required && !rhs.annotations.required) || (!lhs.annotations.readOnly && rhs.annotations.readOnly)) {
return -1;
}
else if ((rhs.annotations.readOnly && !lhs.annotations.readOnly) || (!rhs.annotations.readOnly && lhs.annotations.readOnly)) {
return 1;
}
else {
return 0;
}
});
// used to track when to add an extra \n between fields that have comments
let first = true;
for (const field of values(this.Fields)) {
if (field.docs.summary || field.docs.description) {
if (!first) {
// add an extra new-line between fields IFF the field
// has a comment and it's not the very first one.
text += '\n';
}
text += helpers.formatDocComment(field.docs);
}
let typeName = go.getTypeDeclaration(field.type);
if (field.type.kind === 'literal') {
// for constants we use the underlying type name
typeName = go.getLiteralTypeDeclaration(field.type.type);
}
let serialization = field.serializedName;
if (this.Format === 'JSON') {
serialization += ',omitempty';
}
else if (this.Format === 'XML') {
serialization = getXMLSerialization(field, false);
}
let tag = '';
// only emit tags for XML; JSON uses custom marshallers/unmarshallers
if (this.Format === 'XML' && !field.annotations.isAdditionalProperties) {
tag = ` \`xml:"${serialization}"\``;
}
text += `\t${field.name} ${getStar(field.byValue)}${typeName}${tag}\n`;
first = false;
}
text += '}\n\n';
return text;
}
receiverName() {
const typeName = this.Name;
return typeName[0].toLowerCase();
}
}
export function getXMLSerialization(field, isResponseEnvelope) {
let serialization = field.serializedName;
// default to using the serialization name
if (field.xml?.name) {
// xml can specifiy its own name, prefer that if available
serialization = field.xml.name;
}
else if (field.xml?.text) {
// type has the x-ms-text attribute applied so it should be character data, not a node (https://github.com/Azure/autorest/tree/main/docs/extensions#x-ms-text)
// see https://pkg.go.dev/encoding/xml#Unmarshal for what ,chardata actually means
serialization = ',chardata';
}
if (field.xml?.attribute) {
// value comes from an xml attribute
serialization += ',attr';
}
else if (field.type.kind === 'slice') {
// start with the serialized name of the element, preferring xml name if available
let inner = go.getTypeDeclaration(field.type.elementType);
if (field.xml?.name) {
inner = field.xml.name;
}
// arrays can be wrapped or unwrapped. here's a wrapped example
// note how the array of apple objects is "wrapped" in GoodApples
// <AppleBarrel>
// <GoodApples>
// <Apple>Fuji</Apple>
// <Apple>Gala</Apple>
// </GoodApples>
// </AppleBarrel>
// here's an unwrapped example, the array of slide objects
// is embedded directly in the object (no "wrapping")
// <slideshow>
// <slide>
// <title>Wake up to WonderWidgets!</title>
// </slide>
// <slide>
// <title>Overview</title>
// </slide>
// </slideshow>
// arrays in the response type are handled slightly different as we
// unmarshal directly into them so no need to add the unwrapping.
if (field.xml?.wraps && !isResponseEnvelope) {
serialization += `>${field.xml.wraps}`;
}
else {
serialization = inner;
}
}
return serialization;
}
export function getStar(byValue) {
if (byValue === true) {
return '';
}
return '*';
}
//# sourceMappingURL=models.js.map