UNPKG

@eeacms/volto-chatbot

Version:

@eeacms/volto-chatbot: Volto add-on

201 lines (180 loc) 5.57 kB
import fetch from 'node-fetch';
import {
  getClaimsFromResponse,
  getClassifierProbabilitiesFromLogits,
  getTokenProbabilitiesFromLogits,
} from './postprocessing';
import {
  createHalloumiClassifierPrompts,
  createHalloumiPrompt,
} from './preprocessing';

/**
 * Logistic function.
 * @param {number} x Any real number.
 * @returns {number} 1 / (1 + e^-x).
 */
function sigmoid(x) {
  return 1 / (1 + Math.exp(-x));
}

/**
 * Calibrates a probability with Platt scaling.
 *
 * The probability is clamped to [1e-6, 1 - 1e-6] so the log-odds stay finite,
 * then mapped through sigmoid(-(a * logOdds + b)).
 *
 * @param {{a: number, b: number}} platt Platt scaling coefficients.
 * @param {number} probability Raw probability to calibrate.
 * @returns {number} The calibrated probability.
 */
export function applyPlattScaling(platt, probability) {
  // Work on a local copy instead of reassigning the parameter.
  const clamped = Math.min(Math.max(probability, 1e-6), 1 - 1e-6);
  const logOdds = Math.log(clamped / (1 - clamped));
  return sigmoid(-1 * (platt.a * logOdds + platt.b));
}

/**
 * Builds the JSON request headers shared by both HALLOUMI endpoints.
 * An Authorization bearer header is added only when the model has an apiKey.
 */
function buildRequestHeaders(model) {
  const headers = {
    'Content-Type': 'application/json',
    accept: 'application/json',
  };
  if (model.apiKey) {
    headers['Authorization'] = `Bearer ${model.apiKey}`;
  }
  return headers;
}

/**
 * POSTs `payload` as JSON to `model.apiUrl` and returns the parsed JSON body.
 * Throws when the server explicitly reports a non-2xx status.
 */
async function postJson(model, payload) {
  const response = await fetch(model.apiUrl, {
    method: 'POST',
    headers: buildRequestHeaders(model),
    body: JSON.stringify(payload),
  });
  // Compare against `false` (rather than `!response.ok`) so that response
  // stand-ins which do not define `ok` keep working as before.
  if (response.ok === false) {
    throw new Error(
      `Request to ${model.apiUrl} failed with status ${response.status} ${response.statusText}`,
    );
  }
  return response.json();
}

/**
 * Scores claims with the embedding-style (classifier) endpoint.
 *
 * One prompt per claim is sent as `input`; `data[i].embedding` of the reply
 * is converted to class probabilities by getClassifierProbabilitiesFromLogits.
 *
 * @param {object} model Reads `name`, `apiUrl`, optional `apiKey` and
 *   optional `plattScaling` ({a, b}).
 * @param {string} context The context the claims are checked against.
 * @param {string} claims The text containing the claims.
 * @returns {Promise<{claims: Array<object>}>} One entry per prompt with
 *   `startOffset`, `endOffset`, `citationIds` (always empty), `score`
 *   (probability of the supported class) and `rationale` (always '').
 * @throws {Error} When the server reports a non-2xx status.
 */
export async function halloumiClassifierAPI(model, context, claims) {
  const classifierPrompts = createHalloumiClassifierPrompts(context, claims);
  const jsonData = await postJson(model, {
    input: classifierPrompts.prompts,
    model: model.name,
  });

  const output = {
    claims: [],
  };
  for (let i = 0; i < classifierPrompts.prompts.length; i++) {
    const embedding = jsonData.data[i].embedding;
    // 0-th index is the supported class.
    // 1-th index is the unsupported class.
    const probs = getClassifierProbabilitiesFromLogits(embedding);
    if (model.plattScaling) {
      const unsupportedScore = applyPlattScaling(model.plattScaling, probs[1]);
      probs[0] = 1 - unsupportedScore;
      probs[1] = unsupportedScore;
    }
    // responseOffsets is looked up with i + 1 (keys presumably start at 1).
    const offset = classifierPrompts.responseOffsets.get(i + 1);
    output.claims.push({
      startOffset: offset.startOffset,
      endOffset: offset.endOffset,
      citationIds: [],
      score: probs[0],
      rationale: '',
    });
  }
  return output;
}

/**
 * Verifies claims against a context, using the classifier endpoint when
 * `model.isEmbeddingModel` is truthy and the generative endpoint otherwise.
 *
 * @param {object} model Model configuration (see the two API functions).
 * @param {string} context The context the claims are checked against.
 * @param {string} claims The text containing the claims.
 * @returns {Promise<{claims: Array<object>, citations: object}>} Empty
 *   `claims`/`citations` when `context` or `claims` is falsy.
 */
export async function getVerifyClaimResponse(model, context, claims) {
  if (!context || !claims) {
    return {
      claims: [],
      citations: {},
    };
  }

  if (model.isEmbeddingModel) {
    const classified = await halloumiClassifierAPI(model, context, claims);
    return {
      claims: classified.claims,
      citations: {},
    };
  }

  const prompt = createHalloumiPrompt(context, claims);
  const generativeClaims = await halloumiGenerativeAPI(model, prompt);
  return convertGenerativesClaimToVerifyClaimResponse(generativeClaims, prompt);
}

// Tokens whose probabilities are extracted from the returned logprobs.
const tokenChoices = new Set(['supported', 'unsupported']);

/**
 * Sends a prompt to the generative (chat-completion style) endpoint and
 * returns the parsed claims, each with a `probabilities` map attached.
 *
 * @param {object} model Reads `name`, `apiUrl`, optional `apiKey` and
 *   optional `plattScaling` ({a, b}).
 * @param {object} prompt Prompt object; `prompt.prompt` is sent as the
 *   single user message.
 * @returns {Promise<Array<object>>} Claims parsed by getClaimsFromResponse,
 *   each extended with `probabilities` (keys 'supported' / 'unsupported').
 * @throws {Error} When the server reports a non-2xx status, or when the
 *   number of parsed claims differs from the number of token probabilities.
 */
export async function halloumiGenerativeAPI(model, prompt) {
  const jsonData = await postJson(model, {
    messages: [{ role: 'user', content: prompt.prompt }],
    temperature: 0.0,
    model: model.name,
    logprobs: true,
    top_logprobs: 3,
  });

  const choice = jsonData.choices[0];
  const tokenProbabilities = getTokenProbabilitiesFromLogits(
    choice.logprobs.content,
    tokenChoices,
  );
  const parsedResponse = getClaimsFromResponse(choice.message.content);

  if (parsedResponse.length !== tokenProbabilities.length) {
    throw new Error('Token probabilities and claims do not match.');
  }

  for (let i = 0; i < parsedResponse.length; i++) {
    const scoreMap = tokenProbabilities[i];
    if (model.plattScaling) {
      const unsupportedScore = applyPlattScaling(
        model.plattScaling,
        scoreMap.get('unsupported'),
      );
      scoreMap.set('supported', 1 - unsupportedScore);
      scoreMap.set('unsupported', unsupportedScore);
    }
    parsedResponse[i].probabilities = scoreMap;
  }

  return parsedResponse;
}

/**
 * Converts generative-model claims into the verify-claim response shape.
 *
 * @param {Array<object>} generativeClaims Claims as returned by
 *   halloumiGenerativeAPI (`claimId`, `citations`, `probabilities`,
 *   `explanation`).
 * @param {object} prompt Prompt object providing `contextOffsets` (iterable
 *   of [id, {startOffset, endOffset}] pairs) and `responseOffsets` (supports
 *   `has`/`get` by claim id).
 * @returns {{claims: Array<object>, citations: object}} Claims with offsets,
 *   stringified citation ids, supported-score and rationale; citations keyed
 *   by stringified id.
 * @throws {Error} When a claim id is missing from `prompt.responseOffsets`.
 */
export function convertGenerativesClaimToVerifyClaimResponse(
  generativeClaims,
  prompt,
) {
  const citations = {};
  for (const [rawId, window] of prompt.contextOffsets) {
    const id = rawId.toString();
    citations[id] = {
      startOffset: window.startOffset,
      endOffset: window.endOffset,
      id,
    };
  }

  const claims = [];
  for (const generativeClaim of generativeClaims) {
    const citationIds = [...generativeClaim.citations].map((citation) =>
      citation.toString(),
    );

    const { claimId } = generativeClaim;
    if (!prompt.responseOffsets.has(claimId)) {
      throw new Error(`Claim ${claimId} not found in response offsets.`);
    }
    const claimResponseWindow = prompt.responseOffsets.get(claimId);

    claims.push({
      startOffset: claimResponseWindow.startOffset,
      endOffset: claimResponseWindow.endOffset,
      citationIds,
      score: generativeClaim.probabilities.get('supported'),
      rationale: generativeClaim.explanation,
    });
  }

  return {
    claims,
    citations,
  };
}