Spaces:
Running
Running
File size: 3,020 Bytes
6e29063 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import { rerank } from "./rerankerService";
/**
 * Orders search results using the scores returned by `rerank`.
 *
 * Each result is turned into a lower-cased Markdown-link-like string
 * (`[title](url "snippet")`, with double quotes in the snippet replaced by
 * single quotes) and sent to `rerank` together with the lower-cased query.
 * Every `{ index, relevance_score }` entry that comes back is paired with
 * `searchResults[index]`.
 *
 * @param query - The user's search query.
 * @param searchResults - `[title, content, url]` tuples to rank; not mutated.
 * @param preserveTopResults - When `false` (default), all scored entries are
 *   passed through `filterResultsByScore` and sorted by descending score.
 *   When `true`, the first entry returned by `rerank` is always kept in first
 *   position; the remaining entries are filtered, then the first nine
 *   survivors are sorted by descending score among themselves, followed by
 *   the rest, also sorted by descending score among themselves.
 * @returns The selected result tuples in their final order, or an empty array
 *   when `rerank` returns nothing.
 */
export async function rankSearchResults(
  query: string,
  searchResults: [title: string, content: string, url: string][],
  preserveTopResults = false,
) {
  const byDescendingScore = (
    left: { score: number },
    right: { score: number },
  ) => right.score - left.score;

  // One text document per result, shaped like a Markdown link with a title.
  const documents = searchResults.map(([title, snippet, url]) => {
    const quotedSnippet = snippet.replaceAll('"', "'");
    return `[${title}](${url} "${quotedSnippet}")`.toLocaleLowerCase();
  });

  const reranked = await rerank(query.toLocaleLowerCase(), documents);
  if (reranked.length === 0) {
    return [];
  }

  const scored = reranked.map(({ index, relevance_score }) => ({
    result: searchResults[index],
    score: relevance_score,
  }));

  if (preserveTopResults) {
    // The leading entry bypasses filtering and sorting entirely.
    const [pinned, ...others] = scored;
    const survivors = filterResultsByScore(others);

    // The first nine survivors and the remainder are sorted independently,
    // so an entry never crosses from one group into the other.
    const leadingGroupSize = 9;
    const leadingGroup = survivors
      .slice(0, leadingGroupSize)
      .sort(byDescendingScore);
    const trailingGroup = survivors
      .slice(leadingGroupSize)
      .sort(byDescendingScore);

    return [pinned, ...leadingGroup, ...trailingGroup].map(
      ({ result }) => result,
    );
  }

  const survivors = filterResultsByScore(scored);
  survivors.sort(byDescendingScore);
  return survivors.map(({ result }) => result);
}
/** A single search result as a `[title, content, url]` tuple. */
type SearchResultTuple = [title: string, content: string, url: string];

/** A search result paired with its relevance score. */
interface ScoredResultItem {
  result: SearchResultTuple;
  score: number;
}

/** A scored result that also carries the shifted score used for filtering. */
interface ScoredResultItemWithNormalizedScore extends ScoredResultItem {
  normalizedScore: number;
}
/**
 * Removes low-scoring entries from a list of scored results.
 *
 * Every score is first shifted upward by the absolute value of the lowest
 * score in the list. An entry is kept when its shifted score is at least
 * `mean - kStandardDeviationFactor * standardDeviation` of the shifted scores
 * (population standard deviation; the cut-off is clamped to be at least 0).
 *
 * If fewer than `ceil(count * minPercentageFallback)` entries survive, the
 * selection is redone with a different rule: keep every entry whose shifted
 * score is at least `minPercentageFallback` times the highest shifted score.
 *
 * @param inputResults - Scored entries to filter; the array is not mutated.
 * @param kStandardDeviationFactor - Number of standard deviations below the
 *   mean at which the primary cut-off is placed.
 * @param minPercentageFallback - Both the minimum fraction of entries the
 *   primary rule must keep and the fraction of the top shifted score used as
 *   the fallback cut-off.
 * @returns A new array, in input order, of the surviving entries, each
 *   extended with its `normalizedScore`.
 */
function filterResultsByScore(
  inputResults: ScoredResultItem[],
  kStandardDeviationFactor = 0.3,
  minPercentageFallback = 0.4,
): ScoredResultItemWithNormalizedScore[] {
  const count = inputResults.length;
  if (count === 0) {
    return [];
  }

  // Shift scores by |min| so that a negative minimum ends up at zero.
  const lowestScore = Math.min(...inputResults.map((entry) => entry.score));
  const offset = Math.abs(lowestScore);
  const shifted: ScoredResultItemWithNormalizedScore[] = inputResults.map(
    (entry) => ({ ...entry, normalizedScore: entry.score + offset }),
  );

  let total = 0;
  for (const { normalizedScore } of shifted) {
    total = total + normalizedScore;
  }
  const average = total / count;

  let squaredDeviationTotal = 0;
  for (const { normalizedScore } of shifted) {
    squaredDeviationTotal =
      squaredDeviationTotal + (normalizedScore - average) ** 2;
  }
  const deviation = Math.sqrt(squaredDeviationTotal / count);

  const cutoff = Math.max(0, average - kStandardDeviationFactor * deviation);
  const kept = shifted.filter((entry) => entry.normalizedScore >= cutoff);

  const minimumKept = Math.ceil(count * minPercentageFallback);
  if (kept.length < minimumKept) {
    // Too few survivors: switch to a cut-off relative to the best score.
    const bestShiftedScore = Math.max(
      ...shifted.map((entry) => entry.normalizedScore),
    );
    const fallbackCutoff = bestShiftedScore * minPercentageFallback;
    return shifted.filter((entry) => entry.normalizedScore >= fallbackCutoff);
  }

  return kept;
}
|