|
import os |
|
from typing import Any |
|
|
|
from huggingface_hub import InferenceClient |
|
|
|
from rag_demo.rag.base.query import Query |
|
from rag_demo.rag.base.template_factory import RAGStep |
|
from rag_demo.rag.prompt_templates import QueryExpansionTemplate |
|
|
|
|
|
class QueryExpansion(RAGStep):
    """Expand a user query into several alternative phrasings via an LLM.

    The original query is always kept as the first element of the result; the
    remaining elements are model-generated rewrites of it, built with
    ``query.replace_content``.
    """

    # Hugging Face Inference model used to produce the rewrites; override on a
    # subclass or instance to use a different model.
    model_id: str = "Qwen/Qwen2.5-72B-Instruct"

    def generate(self, query: Query, expand_to_n: int) -> list[Query]:
        """Return up to ``expand_to_n`` queries: the original plus LLM rewrites.

        Args:
            query: The query to expand. It is returned unchanged as the first
                element; rewrites are derived from it via ``query.replace_content``.
            expand_to_n: Upper bound on the total number of queries returned,
                including the original. Must be at least 1.

        Returns:
            A list whose first element is ``query``, followed by at most
            ``expand_to_n - 1`` rewritten queries. It may be shorter when the
            model returns fewer non-empty rewrites (or an empty completion).

        Raises:
            ValueError: If ``expand_to_n`` is less than 1.
        """
        if expand_to_n < 1:
            raise ValueError(
                f"'expand_to_n' should be at least 1. Got {expand_to_n}."
            )

        # The original query occupies one slot; the model fills the rest.
        n_alternatives = expand_to_n - 1
        if n_alternatives == 0:
            # Nothing to generate: skip the remote call entirely.
            return [query]

        api = InferenceClient(
            model=self.model_id,
            token=os.getenv("HF_API_TOKEN"),
        )

        query_expansion_template = QueryExpansionTemplate()
        separator = query_expansion_template.separator
        prompt = query_expansion_template.create_template(n_alternatives)

        response = api.chat_completion(
            [
                {
                    "role": "user",
                    # Render the prompt with the same count the template was
                    # created for. Passing the full ``expand_to_n`` here asked
                    # the model for one rewrite too many. NOTE(review): assumes
                    # the template's {expand_to_n} placeholder means "number of
                    # alternatives to generate" — confirm in prompt_templates.
                    "content": prompt.template.format(
                        question=query.content,
                        expand_to_n=n_alternatives,
                        separator=separator,
                    ),
                }
            ]
        )

        # The completion's message content is optional; treat a missing value
        # as "no rewrites" instead of crashing on ``None.split``.
        result = response.choices[0].message.content or ""

        rewrites = [
            query.replace_content(stripped_content)
            for content in result.split(separator)
            if (stripped_content := content.strip())
        ]

        # The model may ignore the requested count; never exceed expand_to_n
        # queries in total.
        return [query, *rewrites[:n_alternatives]]
|
|