annikwag's picture
Update app.py
d237e1f verified
raw
history blame
19.2 kB
import streamlit as st
import requests
import pandas as pd
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import load_region_data, get_country_name, get_regions
from appStore.tfidf_extraction import extract_top_keywords
from torch import cuda
import json
from datetime import datetime
###########
# Dedicated inference endpoint configuration.
# (Previously these values came from getconfig("model_params.cfg").)
# TODO: move this configuration, together with the RAG helper, into its own module.

# Hugging Face Inference Endpoint that get_rag_answer posts prompts to.
DEDICATED_ENDPOINT = "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
# Model served by the endpoint above; not included in the request payload.
DEDICATED_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
# Write access token, read from Streamlit secrets (key "Llama_3_1").
# NOTE(review): an inference-only (read) token may suffice here - confirm.
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
def get_rag_answer(query, top_results, max_context_chars=15000, max_new_tokens=150, timeout=60):
    """
    Generate a concise answer to *query* grounded in the given search hits.

    The ``payload["page_content"]`` of every hit in *top_results* is joined into
    one context string, truncated to *max_context_chars* characters, wrapped in
    a prompt together with the query and posted to ``DEDICATED_ENDPOINT``.

    Args:
        query: The user's question.
        top_results: Search hits whose ``payload["page_content"]`` is used as context.
        max_context_chars: Character cap on the combined context; a rough guard
            against exceeding the model's token limit.
        max_new_tokens: Generation budget sent to the endpoint.
        timeout: Seconds to wait for the endpoint before giving up.

    Returns:
        The generated answer text. Request failures, non-200 responses and
        responses with an unexpected shape are reported as a returned string
        starting with "Error in generating answer:" instead of being raised,
        so the page keeps rendering.
    """
    # Join the contexts of all hits (blank line between them).
    context = "\n\n".join(res.payload["page_content"] for res in top_results)
    # Character-based truncation is a crude proxy for the token limit.
    if len(context) > max_context_chars:
        context = context[:max_context_chars]

    # Ask the model to output only the final answer.
    prompt = (
        "Using the following context, answer the question concisely. "
        "Only output the final answer below, without repeating the context or question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

    headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
        },
    }

    try:
        # Without a timeout, a hung endpoint would block the Streamlit script run indefinitely.
        response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload, timeout=timeout)
    except requests.exceptions.RequestException as err:
        return f"Error in generating answer: {err}"

    if response.status_code != 200:
        return f"Error in generating answer: {response.text}"

    try:
        answer = response.json()[0]["generated_text"]
    except (ValueError, LookupError, TypeError):
        # Body is not JSON, or not a list of {"generated_text": ...} objects.
        return f"Error in generating answer: unexpected response {response.text}"

    # If the endpoint echoes the full prompt, keep only the text after the last "Answer:".
    if "Answer:" in answer:
        answer = answer.split("Answer:")[-1].strip()
    return answer
#######
# Compute device: GPU when CUDA is available, CPU otherwise.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Page header and the free-text search box.
st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("GIZ Project Database (PROTOTYPE)")
var = st.text_input("Enter Search Query")

# Region lookup table; the code below reads its 'alpha-2' and 'sub-region' columns.
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)

#################### One-time ingestion: build and store the embeddings collection ####
# Uncomment and run these steps once, then comment them out again so that
# app reruns do not repeat the expensive processing.
# 1. Process the relevant data source into chunks:
#chunks = process_giz_worldwide()
# 2. Convert the chunks into LangChain documents:
#temp_doc = create_documents(chunks, 'chunks')
# 3. Embed and store the documents (del_if_exists=True presumably replaces an
#    existing collection - confirm in appStore.embed):
collection_name = "giz_worldwide"
#hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)

################### Hybrid Search #####################################################
client = get_client()
# Debug output: print the collections available to this client.
print(client.get_collections())

# Upper bound for the end-year slider: the largest end_year in the collection.
max_end_year = get_max_end_year(client, collection_name)

# Sub-region names offered in the region filter (the first return value is unused).
_, unique_sub_regions = get_regions(region_df)
# Fetch unique country codes and map to country names
@st.cache_data
def get_country_name_and_region_mapping(_client, collection_name, region_df):
    """
    Build country lookup tables from the country codes found in the collection.

    Runs an empty-query hybrid search, collects every two-letter code from the
    hits' ``metadata["countries"]`` field and resolves each one via *region_df*.

    Args:
        _client: Search client. The leading underscore tells ``st.cache_data``
            not to hash this argument.
        collection_name: Collection to search.
        region_df: Region lookup table with ``alpha-2`` and ``sub-region`` columns.

    Returns:
        Tuple ``(country_name_to_code, iso_code_to_sub_region)``:
        country name -> ISO2 code, and ISO2 code -> sub-region
        ("Not allocated" when the code is missing from *region_df*).

    NOTE(review): hybrid_search is called without ``limit``, so only the hits
    returned by its default limit are scanned. Countries that occur only in
    other documents are missing from the mapping - confirm this is intended.
    """
    results = hybrid_search(_client, "", collection_name)

    country_set = set()
    for res in results[0] + results[1]:
        countries = res.payload.get('metadata', {}).get('countries', "[]")
        if isinstance(countries, str):
            try:
                # Presumably stored as a single-quoted list literal; make it valid JSON.
                countries = json.loads(countries.replace("'", '"'))
            except json.JSONDecodeError:
                continue
        if not isinstance(countries, list):
            # Unparseable or unexpected shape (e.g. null or a dict): skip this hit.
            continue
        # Only keep two-letter (ISO2) codes.
        country_set.update(
            code.upper() for code in countries if isinstance(code, str) and len(code) == 2
        )

    # Create a mapping of {CountryName -> ISO2Code} and {ISO2Code -> SubRegion}.
    country_name_to_code = {}
    iso_code_to_sub_region = {}
    # Sorted so that, if two codes resolve to the same name, the winner is deterministic.
    for code in sorted(country_set):
        name = get_country_name(code, region_df)
        sub_region_row = region_df[region_df['alpha-2'] == code]
        sub_region = sub_region_row['sub-region'].values[0] if not sub_region_row.empty else "Not allocated"
        country_name_to_code[name] = code
        iso_code_to_sub_region[code] = sub_region
    return country_name_to_code, iso_code_to_sub_region
# Get country name and region mappings.
# Reuses the `client` created earlier in the script instead of calling
# get_client() a second time on every rerun.
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df)
# Alphabetically sorted country names for the "Country" dropdown.
unique_country_names = sorted(country_name_mapping)
# Layout filters in columns (col4 is not used in this file; it only takes up width).
col1, col2, col3, col4 = st.columns([1, 1, 1, 4])

# Region filter: sub-region names, plus an "All" option.
with col1:
    region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))

# Restrict the country choices to the selected region.
if region_filter == "All/Not allocated":
    filtered_country_names = unique_country_names  # already sorted
else:
    # Sorted so the dropdown order matches the unfiltered (sorted) list.
    filtered_country_names = sorted(
        name
        for name, code in country_name_mapping.items()
        if iso_code_to_sub_region.get(code) == region_filter
    )

# Country filter
with col2:
    country_filter = st.selectbox("Country", ["All/Not allocated"] + filtered_country_names)
# Project end-year range filter (applied in filter_results).
with col3:
    current_year = datetime.now().year
    # Lowest selectable year is 2010, lowered only if the collection's max end year is older.
    min_year = min(2010, max_end_year)
    # Default window: the last five years, clamped into [min_year, max_end_year].
    # Streamlit rejects a slider value outside min_value/max_value, which would
    # happen if the newest end year in the data is older than current_year - 5.
    default_start_year = max(min_year, min(current_year - 5, max_end_year))
    end_year_range = st.slider(
        "Project End Year",
        min_value=min_year,
        max_value=max_end_year,
        value=(default_start_year, max_end_year),
    )

# Checkbox to control whether to show only exact (substring) matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)
def filter_results(results, country_filter, region_filter, end_year_range):
    """
    Keep the search hits that match the selected country, region and end-year range.

    Args:
        results: Search hits exposing ``payload["metadata"]``.
        country_filter: Selected country name, or "All/Not allocated".
        region_filter: Selected sub-region name, or "All/Not allocated".
        end_year_range: Inclusive ``(lowest, highest)`` end-year pair.

    Returns:
        The hits that pass all three filters, in their original order.
        Hits without a parseable end year are treated as year 0, so they are
        excluded unless the range includes 0.

    Reads the module-level ``country_name_mapping`` (name -> ISO2) and
    ``iso_code_to_sub_region`` (ISO2 -> sub-region).
    """
    # These do not depend on the individual hit: compute them once.
    any_country = country_filter == "All/Not allocated"
    any_region = region_filter == "All/Not allocated"
    selected_iso_code = country_name_mapping.get(country_filter)
    year_lo, year_hi = end_year_range

    filtered = []
    for r in results:
        metadata = r.payload.get('metadata', {})

        # End year as int; 0 when missing or unparseable.
        end_year_val = 0
        year_str = metadata.get('end_year')
        if year_str:
            extracted = extract_year(year_str)
            if extracted != "Unknown":
                try:
                    end_year_val = int(extracted)
                except (TypeError, ValueError):
                    end_year_val = 0

        # Country codes: accept a list, or a (presumably single-quoted) list literal string.
        countries = metadata.get('countries', "[]")
        if isinstance(countries, str):
            try:
                countries = json.loads(countries.replace("'", '"'))
            except json.JSONDecodeError:
                countries = []
        if not isinstance(countries, list):
            countries = []
        c_list = [code.upper() for code in countries if isinstance(code, str) and len(code) == 2]

        if not (year_lo <= end_year_val <= year_hi):
            continue
        if not any_country and selected_iso_code not in c_list:
            continue
        # With a region selected, at least one of the hit's countries must lie in it.
        if not any_region and not any(
            iso_code_to_sub_region.get(code) == region_filter for code in c_list
        ):
            continue
        filtered.append(r)
    return filtered
# Run the search.
# Request a large candidate pool (limit=500) so that enough hits remain after the
# content, country/region and end-year filters below.
results = hybrid_search(client, var, collection_name, limit=500)
# results is a tuple: (semantic_results, lexical_results)
semantic_all = results[0]
lexical_all = results[1]

# Interim fix: drop near-empty chunks, since very short paragraphs can get
# spuriously high similarity scores. Chunks shorter than this are removed.
MIN_PAGE_CONTENT_CHARS = 5
semantic_all = [
    r for r in semantic_all if len(r.payload["page_content"]) >= MIN_PAGE_CONTENT_CHARS
]
lexical_all = [
    r for r in lexical_all if len(r.payload["page_content"]) >= MIN_PAGE_CONTENT_CHARS
]

# Minimum score for semantic hits. At 0.0 this only drops negative-score hits.
# NOTE(review): an earlier comment mentioned 0.4 - confirm the intended threshold.
SEMANTIC_SCORE_THRESHOLD = 0.0
semantic_thresholded = [r for r in semantic_all if r.score >= SEMANTIC_SCORE_THRESHOLD]

# Apply the country/region/end-year filters to both result sets, then drop duplicates.
filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range)
filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range)
filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
# Define a helper function to format currency values
def format_currency(value):
    """
    Format *value* as a whole-euro amount with thousands separators.

    Anything ``float()`` accepts is truncated toward zero and rendered as,
    e.g., ``"€1,234,567"``. Values that cannot be formatted (None, non-numeric
    strings such as "Unknown", NaN, infinity) are returned unchanged so the
    caller can display them as-is.
    """
    try:
        # Amounts are shown without cents; int() truncates any fractional part.
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # ValueError: non-numeric text or NaN; OverflowError: +/-infinity.
        return value
# Number of top hits shown on the page and passed to the answer generator as context.
TOP_K_RESULTS = 5
# Maximum number of words in a hit's objective/description preview.
_SNIPPET_WORDS = 200


# Parse a hit's "countries" metadata (list, or list-literal string) into a list; [] if unparseable.
def _display_country_list(raw_countries):
    if isinstance(raw_countries, str):
        try:
            # Presumably stored as a single-quoted list literal; make it valid JSON.
            raw_countries = json.loads(raw_countries.replace("'", '"'))
        except json.JSONDecodeError:
            return []
    return raw_countries if isinstance(raw_countries, list) else []


# Show the generated answer for the query, using the given hits as context.
def _render_rag_answer(query, hits):
    # No question entered: skip the remote generation call entirely.
    if not query.strip():
        return
    rag_answer = get_rag_answer(query, hits)
    st.markdown("### Generated Answer")
    st.write(rag_answer)
    st.divider()


# Render one search hit: title, text preview, keywords and a metadata summary.
def _render_search_hit(res):
    metadata = res.payload.get('metadata', {})
    client_name = metadata.get('client', 'Unknown Client')
    start_year = metadata.get('start_year', None)
    end_year = metadata.get('end_year', None)
    project_name = metadata.get('project_name', 'Project Link')
    proj_id = metadata.get('id', 'Unknown')

    st.markdown(f"#### {project_name} [{proj_id}]")

    # Preview built from objectives + description (German preferred, English fallback).
    objectives = metadata.get("objectives", "")
    description = metadata.get("description.de", "") or metadata.get("description.en", "")
    words = f"Objective: {objectives} Description: {description}".split()
    preview_text = " ".join(words[:_SNIPPET_WORDS])
    st.write(preview_text + ("..." if len(words) > _SNIPPET_WORDS else ""))

    # Keywords extracted from the hit's page content.
    top_keywords = extract_top_keywords(res.payload['page_content'], top_n=5)
    if top_keywords:
        st.markdown(f"_{' · '.join(top_keywords)}_")

    # Resolve two-letter codes to names. A lookup that returns the code itself
    # is treated as "no match" and not shown as a name.
    c_list = _display_country_list(metadata.get('countries', "[]"))
    matched_countries = []
    for code in c_list:
        if isinstance(code, str) and len(code) == 2:
            resolved_name = get_country_name(code.upper(), region_df)
            if resolved_name.upper() != code.upper():
                matched_countries.append(resolved_name)

    start_year_str = extract_year(start_year) if start_year else "Unknown"
    end_year_str = extract_year(end_year) if end_year else "Unknown"
    formatted_project_budget = format_currency(metadata.get('total_project', "Unknown"))
    formatted_total_volume = format_currency(metadata.get('total_volume', "Unknown"))

    if matched_countries:
        country_text = ', '.join(matched_countries)
        first_line = f"**{country_text}**, commissioned by **{client_name}**"
    else:
        country_text = ', '.join(str(c) for c in c_list) if c_list else 'Unknown'
        first_line = f"Commissioned by **{client_name}**"

    # Two trailing spaces before "\n" make a Markdown hard line break. A bare
    # "\n" is a soft break, and the lines would run together in one paragraph.
    st.markdown("  \n".join([
        first_line,
        f"Projekt duration **{start_year_str}-{end_year_str}**",
        f"Budget: Project: **{formatted_project_budget}**, Total volume: **{formatted_total_volume}**",
        f"Country: **{country_text}**",
    ]))
    st.divider()


# Render either the exact-match (lexical substring) view or the semantic view.
if show_exact_matches:
    st.write(f"Showing **Top {TOP_K_RESULTS} Lexical Search results** for query: {var}")
    # Keep lexical hits whose text contains the query (case-insensitive substring).
    query_substring = var.strip().lower()
    lexical_substring_filtered = [
        r for r in lexical_all if query_substring in r.payload["page_content"].lower()
    ]
    # Then apply the country/region/end-year filters and remove duplicates.
    filtered_lexical = filter_results(
        lexical_substring_filtered, country_filter, region_filter, end_year_range
    )
    filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
    if not filtered_lexical_no_dupe:
        st.write('No exact matches, consider unchecking "Show only exact matches"')
    else:
        top_results = filtered_lexical_no_dupe[:TOP_K_RESULTS]
        _render_rag_answer(var, top_results)
        for res in top_results:
            _render_search_hit(res)
else:
    st.write(f"Showing **Top {TOP_K_RESULTS} Semantic Search results** for query: {var}")
    if not filtered_semantic_no_dupe:
        st.write("No relevant results found.")
    else:
        top_results = filtered_semantic_no_dupe[:TOP_K_RESULTS]
        _render_rag_answer(var, top_results)
        for res in top_results:
            _render_search_hit(res)
# for i in results:
# st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
# st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
# st.write(i.page_content)
# st.divider()