File size: 3,020 Bytes
6e29063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import { rerank } from "./rerankerService";

/**
 * Re-orders search results by their relevance to `query`, as scored by the
 * reranker, and drops results whose score falls well below the rest.
 *
 * Each result is serialized as a lower-cased markdown link
 * (`[title](url "snippet")`) before being sent to the reranker. Double quotes
 * in the snippet are replaced with single quotes so they cannot terminate the
 * link title early.
 *
 * @param query - The search query; lower-cased before reranking.
 * @param searchResults - `[title, content, url]` tuples in their original order.
 * @param preserveTopResults - When `true`, the first original result always
 *   stays first. The remaining results are score-filtered. The first nine
 *   survivors (taken in original order) are then sorted by score among
 *   themselves, followed by the rest of the survivors sorted by score.
 *   When `false`, all results are score-filtered and sorted by score.
 * @returns The kept results (the same tuple objects that were passed in).
 */
export async function rankSearchResults(
  query: string,
  searchResults: [title: string, content: string, url: string][],
  preserveTopResults = false,
): Promise<SearchResultTuple[]> {
  // Nothing to rank, so skip the reranker call entirely.
  if (searchResults.length === 0) return [];

  const documents = searchResults.map(([title, snippet, url]) =>
    `[${title}](${url} "${snippet.replaceAll('"', "'")}")`.toLocaleLowerCase(),
  );

  const results = await rerank(query.toLocaleLowerCase(), documents);

  const scoredResults = results
    // Drop any index the reranker reports that does not map to an input,
    // instead of letting `undefined` leak into the output.
    .flatMap(({ index, relevance_score }) => {
      const result = searchResults[index];
      return result === undefined
        ? []
        : [{ index, result, score: relevance_score }];
    })
    // Restore the original search order. The "first result" and
    // "next N results" logic below is positional, so it must not depend on
    // the order in which the reranker emits its scores.
    .sort((a, b) => a.index - b.index);

  if (!preserveTopResults) {
    return filterResultsByScore(scoredResults)
      .sort((a, b) => b.score - a.score)
      .map(({ result }) => result);
  }

  const [firstResult, ...nextResults] = scoredResults;
  if (firstResult === undefined) return [];

  const filteredNextResults = filterResultsByScore(nextResults);

  // How many of the surviving results right after the first one are ranked
  // as a separate group ahead of everything else.
  const nextTopResultsCount = 9;

  const nextTopResults = filteredNextResults
    .slice(0, nextTopResultsCount)
    .sort((a, b) => b.score - a.score);

  const remainingResults = filteredNextResults
    .slice(nextTopResultsCount)
    .sort((a, b) => b.score - a.score);

  return [firstResult, ...nextTopResults, ...remainingResults].map(
    ({ result }) => result,
  );
}

/** A single search result as a `[title, content, url]` tuple. */
type SearchResultTuple = [title: string, content: string, url: string];

/** A search result paired with the relevance score the reranker assigned it. */
interface ScoredResultItem {
  result: SearchResultTuple;
  score: number;
}

/** A scored result plus the shifted score used when filtering by score. */
interface ScoredResultItemWithNormalizedScore extends ScoredResultItem {
  normalizedScore: number;
}

/**
 * Drops items whose score is well below the rest of the set.
 *
 * Each score is shifted by `Math.abs(minScore)` into `normalizedScore`; when
 * the minimum is negative this moves the lowest score to 0. Items are then
 * kept when `normalizedScore >= mean - kStandardDeviationFactor * stdDev`.
 * That threshold uses the population standard deviation and is clamped at 0.
 *
 * If that primary cut keeps fewer than `ceil(n * minPercentageFallback)`
 * items, a fallback cut is also computed. It keeps items whose normalized
 * score is at least `minPercentageFallback` times the highest normalized
 * score. Whichever cut keeps more items is returned, so the fallback can only
 * widen the selection, never narrow it.
 *
 * Input order is preserved and the input array is not mutated.
 *
 * @param inputResults - Items to filter.
 * @param kStandardDeviationFactor - How many standard deviations below the
 *   mean an item may fall and still pass the primary cut.
 * @param minPercentageFallback - Fraction of items the primary cut must keep
 *   to avoid the fallback. Also the fraction of the top normalized score used
 *   as the fallback threshold.
 * @returns The kept items, each annotated with its `normalizedScore`.
 */
function filterResultsByScore(
  inputResults: ScoredResultItem[],
  kStandardDeviationFactor = 0.3,
  minPercentageFallback = 0.4,
): ScoredResultItemWithNormalizedScore[] {
  if (inputResults.length === 0) return [];

  // reduce instead of Math.min(...spread): spreading a very large array can
  // exceed the engine's maximum argument count.
  const minScore = inputResults.reduce(
    (min, { score }) => Math.min(min, score),
    Infinity,
  );

  // NOTE(review): for a positive minimum this shifts scores *up* by minScore
  // rather than down to 0. Only the proportional fallback threshold below is
  // sensitive to that shift; `item.score - minScore` may have been the intent.
  // Confirm before changing.
  const itemsWithNormalizedScore = inputResults.map((item) => ({
    ...item,
    normalizedScore: item.score + Math.abs(minScore),
  }));

  const count = itemsWithNormalizedScore.length;

  const mean =
    itemsWithNormalizedScore.reduce(
      (sum, { normalizedScore }) => sum + normalizedScore,
      0,
    ) / count;
  const variance =
    itemsWithNormalizedScore.reduce(
      (sum, { normalizedScore }) => sum + (normalizedScore - mean) ** 2,
      0,
    ) / count;
  const standardDeviation = Math.sqrt(variance);

  const keepAtLeast = (threshold: number) =>
    itemsWithNormalizedScore.filter(
      ({ normalizedScore }) => normalizedScore >= threshold,
    );

  const filteredItems = keepAtLeast(
    Math.max(0, mean - kStandardDeviationFactor * standardDeviation),
  );

  if (filteredItems.length >= Math.ceil(count * minPercentageFallback)) {
    return filteredItems;
  }

  const highestNormalizedScore = itemsWithNormalizedScore.reduce(
    (max, { normalizedScore }) => Math.max(max, normalizedScore),
    -Infinity,
  );
  const fallbackItems = keepAtLeast(
    highestNormalizedScore * minPercentageFallback,
  );

  // Both cuts are `>= threshold` filters over the same list, so the larger
  // result is a superset of the smaller one. The fallback exists to keep
  // *more* items, so never let it shrink the selection.
  return fallbackItems.length > filteredItems.length
    ? fallbackItems
    : filteredItems;
}