@neo4j/graphql
Version:
A GraphQL to Cypher query execution layer for Neo4j and JavaScript GraphQL implementations
125 lines • 5.4 kB
JavaScript
;
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Object.defineProperty(exports, "__esModule", { value: true });
exports.VectorFactory = void 0;
const check_authentication_1 = require("../../../authorization/check-authentication");
const ScoreField_1 = require("../../ast/fields/ScoreField");
const ScoreFilter_1 = require("../../ast/filters/property-filters/ScoreFilter");
const VectorOperation_1 = require("../../ast/operations/VectorOperation");
const VectorSelection_1 = require("../../ast/selection/VectorSelection");
const find_fields_by_name_in_fields_by_type_name_field_1 = require("../parsers/find-fields-by-name-in-fields-by-type-name-field");
const get_fields_by_type_name_1 = require("../parsers/get-fields-by-type-name");
/**
 * Builds the query AST for a vector-index query: a `VectorOperation` whose
 * selection targets the index described by `context.vector`, optionally
 * projecting the similarity score and filtering on it.
 */
class VectorFactory {
    /**
     * @param queryASTFactory - parent factory; its `operationsFactory` is used to read
     *   `where` arguments and to hydrate the resulting connection operation.
     */
    constructor(queryASTFactory) {
        this.queryASTFactory = queryASTFactory;
    }
    /**
     * Creates and hydrates the vector connection operation for `entity`.
     *
     * @param entity - entity adapter; `entity.entity` is checked for READ authentication and
     *   `entity.operations.vectorTypeNames` supplies the connection/edge type names.
     * @param resolveTree - resolve tree of the vector query field.
     * @param context - query context; `context.vector` holds the index and score variable.
     * @returns the hydrated `VectorOperation`.
     * @throws {Error} "Vector result field not found" when the connection type is not selected;
     *   also propagates the errors thrown by `getVectorOptions`.
     */
    createVectorOperation(entity, resolveTree, context) {
        const resolveTreeWhere = this.queryASTFactory.operationsFactory.getWhereArgs(resolveTree) ?? {};
        (0, check_authentication_1.checkEntityAuthentication)({
            entity: entity.entity,
            targetOperations: ["READ"],
            context,
        });
        const vectorTypeNames = entity.operations.vectorTypeNames;
        const vectorConnectionFields = resolveTree.fieldsByTypeName[vectorTypeNames.connection];
        if (!vectorConnectionFields) {
            throw new Error("Vector result field not found");
        }
        const filteredResolveTreeEdges = (0, find_fields_by_name_in_fields_by_type_name_field_1.findFieldsByNameInFieldsByTypeNameField)(vectorConnectionFields, "edges");
        // Resolved once and reused for both the score lookup and edge hydration below.
        const edgeFields = (0, get_fields_by_type_name_1.getFieldsByTypeName)(filteredResolveTreeEdges, vectorTypeNames.edge);
        // Only the first `score` selection is projected; any further aliases are ignored.
        const [firstScoreField] = (0, find_fields_by_name_in_fields_by_type_name_field_1.findFieldsByNameInFieldsByTypeNameField)(edgeFields, "score");
        const scoreField = firstScoreField && context.vector
            ? new ScoreField_1.ScoreField({
                alias: firstScoreField.alias,
                score: context.vector.scoreVariable,
            })
            : undefined;
        const operation = new VectorOperation_1.VectorOperation({
            target: entity,
            scoreField,
            selection: this.getVectorSelection(entity, context),
        });
        this.addVectorScoreFilter({
            operation,
            context,
            whereArgs: resolveTreeWhere,
        });
        this.queryASTFactory.operationsFactory.hydrateConnectionOperation({
            target: entity,
            resolveTree: resolveTree,
            context,
            operation,
            whereArgs: resolveTreeWhere,
            resolveTreeEdgeFields: edgeFields,
        });
        return operation;
    }
    /**
     * Adds a `ScoreFilter` (bounded by `whereArgs.score.min` / `.max`) to `operation`
     * when a `score` where-argument is present and `context.vector` is set; otherwise a no-op.
     */
    addVectorScoreFilter({ operation, whereArgs, context, }) {
        if (whereArgs.score && context?.vector) {
            const scoreFilter = new ScoreFilter_1.ScoreFilter({
                scoreVariable: context.vector.scoreVariable,
                min: whereArgs.score.min,
                max: whereArgs.score.max,
            });
            operation.addFilters(scoreFilter);
        }
    }
    /**
     * Builds the `VectorSelection` for `entity` from the validated vector options.
     *
     * @throws {Error} whatever `getVectorOptions` throws for missing context or invalid arguments.
     */
    getVectorSelection(entity, context) {
        const vectorOptions = this.getVectorOptions(context);
        return new VectorSelection_1.VectorSelection({
            target: entity,
            vectorOptions,
            scoreVariable: vectorOptions.score,
            settings: context.vector?.vectorSettings,
        });
    }
    /**
     * Validates the query arguments and returns the vector search options.
     * A `vector` argument takes precedence; otherwise a `phrase` argument is required.
     *
     * @param context - query context; reads `context.vector` and `context.resolveTree.args`.
     * @returns `{ index, vector, score }` or `{ index, phrase, score }`.
     * @throws {Error} "Vector context is missing" | "Invalid vector" | "Invalid phrase".
     */
    getVectorOptions(context) {
        if (!context.vector) {
            throw new Error("Vector context is missing");
        }
        const { index, scoreVariable } = context.vector;
        const { vector, phrase } = context.resolveTree.args;
        if (vector) {
            // Empty vectors and non-finite components (NaN, ±Infinity) are rejected up front
            // instead of being forwarded into the query.
            if (!Array.isArray(vector) || vector.length === 0) {
                throw new Error("Invalid vector");
            }
            if (!vector.every((v) => typeof v === "number" && Number.isFinite(v))) {
                throw new Error("Invalid vector");
            }
            return {
                index,
                vector,
                score: scoreVariable,
            };
        }
        if (!phrase || typeof phrase !== "string") {
            throw new Error("Invalid phrase");
        }
        return {
            index,
            phrase,
            score: scoreVariable,
        };
    }
}
exports.VectorFactory = VectorFactory;
//# sourceMappingURL=VectorFactory.js.map