Raphaël Bournhonesque
committed on
Commit
·
c055452
1
Parent(s):
a85989c
add app
Browse files- app.py +133 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from annotated_text import annotated_text
|
4 |
+
import requests
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
|
8 |
+
# Splits a barcode into three groups of 3 characters followed by the
# remainder; barcodes shorter than 9 characters do not match.
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")


def split_barcode(barcode: str) -> list[str]:
    """Split barcode in the same way as done by Product Opener to generate a
    product image folder.

    :param barcode: The barcode of the product. For the pro platform only,
        it must be prefixed with the org ID using the format
        `{ORG_ID}/{BARCODE}`
    :raises ValueError: raise a ValueError if `barcode` is invalid
    :return: the list of path components: three groups of 3 digits followed
        by the remaining digits (if any) for barcodes of 9 digits or more,
        the barcode itself otherwise. When an org ID is provided, it is the
        first item of the list.
    """
    org_id = None
    if "/" in barcode:
        # For the pro platform, `barcode` is expected to be in the format
        # `{ORG_ID}/{BARCODE}` (ex: `org-lea-nature/3307130803004`)
        org_id, barcode = barcode.split("/", maxsplit=1)

    if not barcode.isdigit():
        raise ValueError(f"unknown barcode format: {barcode}")

    match = BARCODE_PATH_REGEX.fullmatch(barcode)

    # The last group is empty for 9-digit barcodes: drop empty components
    splits = [x for x in match.groups() if x] if match else [barcode]

    if org_id is not None:
        # For the pro platform only, images and OCRs belonging to an org
        # are stored in a folder named after the org for all its products, ex:
        # https://images.pro.openfoodfacts.org/images/products/org-lea-nature/330/713/080/3004/1.jpg
        # The org folder comes *before* the barcode components, so it must
        # be the first path component.
        splits.insert(0, org_id)

    return splits
|
41 |
+
|
42 |
+
|
43 |
+
def _generate_file_path(barcode: str, image_id: str, suffix: str):
    """Build the path of a file attached to a product image.

    :param barcode: barcode of the product, forwarded to `split_barcode`
    :param image_id: identifier of the image
    :param suffix: text appended right after the image ID (ex: `.json`)
    :return: a path starting with `/`, made of the split barcode components
        followed by `{image_id}{suffix}`
    """
    folder = "/".join(split_barcode(barcode))
    filename = f"{image_id}{suffix}"
    return f"/{folder}/{filename}"
|
46 |
+
|
47 |
+
|
48 |
+
def generate_ocr_path(barcode: str, image_id: str) -> str:
    """Return the path of the OCR result (`.json` file) of a product image.

    :param barcode: barcode of the product
    :param image_id: identifier of the image
    """
    ocr_suffix = ".json"
    return _generate_file_path(barcode, image_id, suffix=ocr_suffix)
|
50 |
+
|
51 |
+
|
52 |
+
def generate_image_path(barcode: str, image_id: str) -> str:
    """Return the path of the `.400.jpg` version of a product image.

    :param barcode: barcode of the product
    :param image_id: identifier of the image
    """
    image_suffix = ".400.jpg"
    return _generate_file_path(barcode, image_id, suffix=image_suffix)
|
54 |
+
|
55 |
+
|
56 |
+
@st.cache_data
def send_prediction_request(ocr_url: str, timeout: float = 30.0):
    """Ask Robotoff to predict ingredient lists from an OCR result.

    The result is cached by Streamlit (`st.cache_data`).

    :param ocr_url: URL of the OCR JSON file, sent as the `ocr_url` query
        parameter
    :param timeout: number of seconds to wait for the server before giving
        up, defaults to 30
    :raises requests.HTTPError: if the server answers with an error status
    :raises requests.Timeout: if the server does not answer in time
    :return: the decoded JSON body of the response
    """
    r = requests.get(
        "https://robotoff.openfoodfacts.net/api/v1/predict/ingredient_list",
        params={"ocr_url": ocr_url},
        # Without a timeout, `requests` may wait forever on a stalled server
        timeout=timeout,
    )
    # Fail loudly on HTTP errors, so that an error payload is never returned
    # (and cached) as if it were a prediction
    r.raise_for_status()
    return r.json()
|
62 |
+
|
63 |
+
|
64 |
+
def get_product(barcode: str, timeout: float = 10.0):
    """Fetch a product from the Open Food Facts API (v2).

    :param barcode: barcode of the product to fetch
    :param timeout: number of seconds to wait for the server before giving
        up, defaults to 10
    :raises requests.HTTPError: if the server answers with an error status
        other than 404
    :raises requests.Timeout: if the server does not answer in time
    :return: the value of the `product` field of the response, or None if
        the server answered 404 or if the response has no `product` field
    """
    r = requests.get(
        f"https://world.openfoodfacts.org/api/v2/product/{barcode}",
        # Without a timeout, `requests` may wait forever on a stalled server
        timeout=timeout,
    )

    if r.status_code == 404:
        return None

    # Any other error status (5xx, rate limiting...) is reported explicitly
    # instead of failing later while decoding an unexpected body
    r.raise_for_status()

    # A payload without a `product` field is reported as "no product"
    # rather than raising a KeyError
    return r.json().get("product")
|
71 |
+
|
72 |
+
|
73 |
+
def display_ner_tags(text: str, entities: list[dict]):
    """Render `text` with every entity highlighted and labelled.

    Each entity is shown with its score (3 decimals) and its language; the
    text located between entities is shown as is.

    :param text: the full text the entities refer to
    :param entities: the entities to highlight, each one providing the
        `score`, `lang` (a dict with a `lang` key), `start` and `end` keys.
        NOTE(review): the entities are presumably sorted by `start` and
        non-overlapping, otherwise some text would be repeated - confirm.
    """
    chunks = []
    cursor = 0
    for item in entities:
        score, lang = item["score"], item["lang"]["lang"]
        begin, stop = item["start"], item["end"]
        # Plain text located between the previous entity and this one
        chunks.append(text[cursor:begin])
        # The entity itself, as a (text, label) pair
        label = f"score={score:.3f} | lang={lang}"
        chunks.append((text[begin:stop], label))
        cursor = stop
    # Plain text remaining after the last entity
    chunks.append(text[cursor:])
    annotated_text(chunks)
|
86 |
+
|
87 |
+
|
88 |
+
def run(barcode: str, min_threshold: float = 0.5):
    """Display the images of a product on which ingredient lists were
    detected, along with the detected entities.

    For every image with a numeric ID, a prediction is requested on its OCR
    result; the image is displayed only if at least one entity has a score
    greater than or equal to `min_threshold`.

    :param barcode: barcode of the product
    :param min_threshold: minimum score of the entities to display, defaults
        to 0.5
    """
    product = get_product(barcode)

    if not product:
        st.error(f"Product {barcode} not found")
        return

    # NOTE(review): a product without any image presumably has no `images`
    # field - treat it as "no image" instead of failing with a KeyError.
    images = product.get("images") or {}
    base_url = "https://static.openfoodfacts.org/images/products"

    for image_id in images:
        # Only images with a numeric ID are processed
        if not image_id.isdigit():
            continue

        ocr_url = f"{base_url}{generate_ocr_path(barcode, image_id)}"
        prediction = send_prediction_request(ocr_url)

        entities = prediction["entities"]
        text = prediction["text"]
        filtered_entities = [e for e in entities if e["score"] >= min_threshold]

        if filtered_entities:
            st.divider()
            image_url = f"{base_url}{generate_image_path(barcode, image_id)}"
            st.image(image_url)
            display_ner_tags(text, filtered_entities)
|
114 |
+
|
115 |
+
|
116 |
+
# Pre-fill the barcode input with the first `barcode` query parameter, if any
query_params = st.experimental_get_query_params()
default_barcode = query_params.get("barcode", [""])[0]

st.title("Ingredient extraction demo")
st.markdown(
    "This demo leverages the ingredient entity detection model, that takes the OCR text as input and predict ingredient lists."
)

barcode = st.text_input(
    "barcode",
    value=default_barcode,
    help="Barcode of the product",
)
threshold = st.number_input(
    "threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.98,
    help="Minimum threshold for entity predictions",
)

# Nothing is displayed until a barcode is provided
if barcode:
    run(barcode, min_threshold=threshold)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
requests==2.28.1
|
2 |
+
streamlit==1.15.1
|
3 |
+
st-annotated-text==4.0.0
|