File size: 6,326 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import List, Optional, Dict, cast

import base64
import functools
import itertools
import json

from azure.core.paging import ItemPaged, PageIterator, ReturnType
from ._generated.models import SearchRequest, SearchDocumentsResult, QueryAnswerResult
from ._api_versions import DEFAULT_VERSION


def convert_search_result(result):
    """Flatten a search result object into a single dictionary.

    The entries of ``result.additional_properties`` are combined with the
    result's search metadata, which is stored under ``@search.*`` keys.

    :param result: object exposing ``additional_properties``, ``score``,
        ``reranker_score``, ``highlights`` and ``captions`` attributes.
    :return: the ``additional_properties`` dictionary, updated in place with
        the ``@search.*`` entries; a new dictionary holding only those
        entries when ``additional_properties`` is ``None``.
    :rtype: dict
    """
    ret = result.additional_properties
    if ret is None:
        # Without this fallback the item assignments below raise TypeError.
        ret = {}
    ret["@search.score"] = result.score
    ret["@search.reranker_score"] = result.reranker_score
    ret["@search.highlights"] = result.highlights
    ret["@search.captions"] = result.captions
    return ret


def pack_continuation_token(response, api_version=DEFAULT_VERSION):
    """Encode the next-page information of a response as an opaque token.

    :param response: object exposing ``next_page_parameters`` (with a
        ``serialize()`` method) and ``next_link``.
    :param api_version: value stored under the token's ``apiVersion`` key.
    :return: base64-encoded bytes of a UTF-8 JSON document with the keys
        ``apiVersion``, ``nextLink`` and ``nextPageParameters``, or ``None``
        when ``response.next_page_parameters`` is ``None``.
    """
    next_page_parameters = response.next_page_parameters
    if next_page_parameters is None:
        return None

    payload = json.dumps(
        {
            "apiVersion": api_version,
            "nextLink": response.next_link,
            "nextPageParameters": next_page_parameters.serialize(),
        }
    )
    return base64.b64encode(payload.encode("utf-8"))


def unpack_continuation_token(token):
    """Decode a continuation token into its next link and next-page request.

    :param token: base64-encoded JSON document containing the keys
        ``nextLink`` and ``nextPageParameters``.
    :return: a ``(next_link, next_page_request)`` tuple, where the request is
        built with ``SearchRequest.deserialize`` from ``nextPageParameters``.
    :rtype: tuple
    """
    payload = json.loads(base64.b64decode(token))
    next_link, raw_request = payload["nextLink"], payload["nextPageParameters"]
    return next_link, SearchRequest.deserialize(raw_request)


class SearchItemPaged(ItemPaged[ReturnType]):
    """Item pager that also exposes metadata taken from the search response.

    Items are produced by chaining together the pages of a single page
    iterator. That same iterator instance serves the ``get_*`` accessors,
    which delegate to its methods of the same name.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Created on first use by _first_iterator_instance() and then reused.
        self._first_page_iterator_instance: Optional[SearchPageIterator] = None

    def __next__(self) -> ReturnType:
        item_stream = self._page_iterator
        if item_stream is None:
            # First call: flatten the pages of the shared page iterator.
            item_stream = itertools.chain.from_iterable(self._first_iterator_instance())
            self._page_iterator = item_stream
        return next(item_stream)

    def _first_iterator_instance(self) -> "SearchPageIterator":
        """Return the shared page iterator, creating it via ``by_page()`` once.

        :return: the cached page iterator
        :rtype: SearchPageIterator
        """
        cached = self._first_page_iterator_instance
        if cached is None:
            cached = cast(SearchPageIterator, self.by_page())
            self._first_page_iterator_instance = cached
        return cached

    def get_facets(self) -> Optional[Dict]:
        """Return any facet results if faceting was requested.

        :return: facet results
        :rtype: dict or None
        """
        page_iterator = self._first_iterator_instance()
        return cast(Dict, page_iterator.get_facets())

    def get_coverage(self) -> float:
        """Return the coverage percentage, if `minimum_coverage` was
        specified for the query.

        :return: coverage percentage
        :rtype: float
        """
        page_iterator = self._first_iterator_instance()
        return cast(float, page_iterator.get_coverage())

    def get_count(self) -> int:
        """Return the count of results if `include_total_count` was
        set for the query.

        :return: count of results
        :rtype: int
        """
        page_iterator = self._first_iterator_instance()
        return cast(int, page_iterator.get_count())

    def get_answers(self) -> Optional[List[QueryAnswerResult]]:
        """Return semantic answers. Only included if the semantic ranker is used
        and answers are requested in the search query via the query_answer parameter.

        :return: answers
        :rtype: list[~azure.search.documents.models.QueryAnswerResult] or None
        """
        page_iterator = self._first_iterator_instance()
        return cast(List[QueryAnswerResult], page_iterator.get_answers())


# The pylint error silenced below seems spurious, as the inner wrapper does, in
# fact, become a method of the class when it is applied.
def _ensure_response(f):
    """Decorate a method so that a response is available before it runs.

    When the instance has no current page (``self._current_page`` is
    ``None``), the wrapper calls ``self._get_next`` with the current
    continuation token, stores the result in ``self._response`` and updates
    ``self.continuation_token`` and ``self._current_page`` from
    ``self._extract_data`` before delegating to ``f``.

    :param f: the method to wrap; it is invoked as ``f(self, *args, **kw)``.
    :return: the wrapping function, which keeps ``f``'s name, docstring and
        annotations.
    """
    # pylint:disable=protected-access
    @functools.wraps(f)
    def wrapper(self, *args, **kw):
        if self._current_page is None:
            self._response = self._get_next(self.continuation_token)
            self.continuation_token, self._current_page = self._extract_data(self._response)
        return f(self, *args, **kw)

    return wrapper


class SearchPageIterator(PageIterator):
    """Page iterator that issues search requests and exposes response metadata.

    Pages are fetched through ``client.documents.search_post``. The ``get_*``
    accessors read their values from the stored response; the
    ``_ensure_response`` decorator fetches one first when no page has been
    loaded yet.
    """

    def __init__(self, client, initial_query, kwargs, continuation_token=None) -> None:
        """Set up the iterator.

        :param client: object whose ``documents.search_post`` performs the request.
        :param initial_query: object whose ``request`` attribute is sent for the
            first page.
        :param kwargs: dict of keyword arguments forwarded to ``search_post``;
            an ``"api_version"`` entry, if present, is popped from it.
        :param continuation_token: optional token to resume from.
        """
        super().__init__(
            get_next=self._get_next_cb,
            extract_data=self._extract_data_cb,
            continuation_token=continuation_token,
        )
        self._client = client
        self._initial_query = initial_query
        self._facets = None
        # Popping mutates the caller's dict, so "api_version" is not forwarded
        # to search_post through self._kwargs.
        self._api_version = kwargs.pop("api_version", DEFAULT_VERSION)
        self._kwargs = kwargs

    def _get_next_cb(self, continuation_token):
        """Fetch a page: the initial query when no token is given, otherwise
        the request decoded from the continuation token."""
        if continuation_token is None:
            request = self._initial_query.request
        else:
            # The next link stored in the token is not used here.
            _, request = unpack_continuation_token(continuation_token)
        return self._client.documents.search_post(search_request=request, **self._kwargs)

    def _extract_data_cb(self, response):
        """Return ``(continuation_token, results)`` for a response, with each
        result converted by ``convert_search_result``."""
        next_token = pack_continuation_token(response, api_version=self._api_version)
        documents = [convert_search_result(item) for item in response.results]
        return next_token, documents

    @_ensure_response
    def get_facets(self) -> Optional[Dict]:
        """Return the response's facets as plain dicts, or ``None``.

        The converted mapping is cached on the instance after the first
        conversion.
        """
        # NOTE(review): the token reset presumably makes later iteration start
        # again from the first page - confirm against azure.core PageIterator.
        self.continuation_token = None
        raw_facets = cast(SearchDocumentsResult, self._response).facets
        if self._facets is None and raw_facets is not None:
            assert raw_facets.items() is not None  # Hint for mypy
            self._facets = {
                name: [entry.as_dict() for entry in entries]
                for name, entries in raw_facets.items()
            }
        return self._facets

    @_ensure_response
    def get_coverage(self) -> float:
        """Return the ``coverage`` value of the stored response."""
        self.continuation_token = None
        stored = cast(SearchDocumentsResult, self._response)
        return cast(float, stored.coverage)

    @_ensure_response
    def get_count(self) -> int:
        """Return the ``count`` value of the stored response."""
        self.continuation_token = None
        stored = cast(SearchDocumentsResult, self._response)
        return cast(int, stored.count)

    @_ensure_response
    def get_answers(self) -> Optional[List[QueryAnswerResult]]:
        """Return the ``answers`` value of the stored response."""
        self.continuation_token = None
        stored = cast(SearchDocumentsResult, self._response)
        return stored.answers