|
import streamlit as st |
|
import requests |
|
import os |
|
import time |
|
|
|
|
|
# Hugging Face access token, read from the environment so it never lives in
# source control. It is None when the variable is not set.
HF_API_KEY = os.getenv("HF_API_KEY")

# Hosted inference endpoints: image -> caption, and text -> chat reply.
IMG2TEXT_API = "https://api-inference.huggingface.co/models/nlpconnect/vit-gpt2-image-captioning"
CHAT_API = "https://api-inference.huggingface.co/models/facebook/blenderbot-3B"

# Only attach an Authorization header when a token is actually configured;
# formatting a missing key would otherwise send the literal "Bearer None".
HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
|
|
|
|
|
st.title("Multimodal Chatbot")

# Seed the conversation once: the history list is created only when the
# session has no "messages" entry yet, and it starts with the assistant's
# greeting.
if "messages" not in st.session_state:
    greeting = "Hello! I'm a chatbot. You can upload an image or ask me anything to get started!"
    st.session_state.messages = [{"role": "assistant", "content": greeting}]
|
|
|
|
|
# Replay every stored message in its own chat bubble, in stored order.
for entry in st.session_state.messages:
    role, text = entry["role"], entry["content"]
    with st.chat_message(role):
        st.write(text)
|
|
|
|
|
# Input widgets: an image picker limited to common raster extensions,
# followed by the free-text chat box.
accepted_extensions = ["jpg", "png", "jpeg"]
uploaded_file = st.file_uploader("Upload an image...", type=accepted_extensions)
user_input = st.chat_input("Ask about this image or anything...")

# Caption produced for the current upload, if any; filled in by the image
# handling code that follows.
image_caption = None
|
|
|
|
|
# --- Image upload handling --------------------------------------------------
# Streamlit re-executes this script on every interaction and the uploader
# keeps returning the same file until the user removes it. Each upload is
# therefore processed only once: its identity is remembered in session state
# so later reruns (for example, sending a chat message) do not re-caption the
# image or append duplicate "[Image Uploaded]" entries to the history.
if not uploaded_file:
    # Nothing is selected: forget the last processed upload so that adding
    # the same file again is captioned afresh.
    st.session_state.pop("processed_upload", None)
elif uploaded_file.type not in ("image/jpeg", "image/png"):
    st.error("⚠️ Please upload a valid JPG or PNG image.")
else:
    # Prefer Streamlit's per-upload id when the installed version exposes
    # one; otherwise fall back to the file's name and size.
    upload_key = getattr(uploaded_file, "file_id", None) or (
        uploaded_file.name,
        uploaded_file.size,
    )

    if st.session_state.get("processed_upload") != upload_key:
        # Mark first so a failed attempt is not silently repeated on every
        # rerun; the user can re-upload to try again.
        st.session_state["processed_upload"] = upload_key

        # getvalue() returns the whole buffer regardless of read position.
        img_bytes = uploaded_file.getvalue()
        st.session_state.messages.append({"role": "user", "content": "[Image Uploaded]"})
        with st.chat_message("user"):
            st.image(img_bytes, caption="Uploaded Image", use_column_width=True)

        # Step 1: ask the captioning model to describe the image.
        image_caption = None
        caption_headers = {**HEADERS, "Content-Type": "application/octet-stream"}
        request_timeout_s = 60  # never let a stalled request hang the app
        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.post(
                    IMG2TEXT_API,
                    headers=caption_headers,
                    data=img_bytes,
                    timeout=request_timeout_s,
                )
            except requests.RequestException as exc:
                st.error(f"⚠️ Image API request failed: {exc}")
                break

            if response.status_code == 503:
                # 503 is treated as "model still loading": wait and retry,
                # but do not sleep after the final attempt.
                if attempt < max_retries:
                    st.warning(
                        f"⏳ Model warming up... Retrying in 5 seconds. Attempt {attempt}/{max_retries}"
                    )
                    time.sleep(5)
                continue

            if response.status_code != 200:
                st.error(f"⚠️ Image API Error: {response.status_code} - {response.text}")
                break

            try:
                res_json = response.json()  # raises ValueError on a non-JSON body
                # The payload may be a list of results or a single result dict.
                if isinstance(res_json, list) and res_json:
                    res_json = res_json[0]
                if isinstance(res_json, dict):
                    image_caption = res_json.get("generated_text")
            except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
                st.error(f"⚠️ Error: Unable to generate caption. Details: {e}")
                break

            if not image_caption:
                st.error("⚠️ Unexpected response format from image captioning API.")
            break
        else:
            # Every attempt came back 503.
            st.error("⚠️ Image captioning model is still loading. Please try again shortly.")

        # Step 2: with a caption in hand, keep it as context for later
        # questions and ask the chat model for an opening remark.
        if image_caption:
            st.session_state.image_caption = image_caption

            fallback_reply = "I received your image. What would you like to ask about it?"
            bot_reply = fallback_reply
            bot_context = (
                f"Consider this image: {image_caption}. Please provide a relevant and engaging response to the image."
            )
            try:
                bot_response = requests.post(
                    CHAT_API,
                    headers=HEADERS,
                    json={"inputs": bot_context},
                    timeout=request_timeout_s,
                )
                if bot_response.status_code == 200:
                    chat_json = bot_response.json()
                    if isinstance(chat_json, list) and chat_json:
                        chat_json = chat_json[0]
                    if isinstance(chat_json, dict):
                        bot_reply = chat_json.get("generated_text") or fallback_reply
            except (requests.RequestException, ValueError):
                # The opening remark is best-effort: any transport or
                # decoding failure falls back to the canned reply.
                bot_reply = fallback_reply

            st.session_state.messages.append({"role": "assistant", "content": bot_reply})
            with st.chat_message("assistant"):
                st.write(bot_reply)
|
|
|
|
|
# --- Text chat handling -----------------------------------------------------
# Sends the user's message (prefixed with the stored image caption, when one
# exists) to the chat model and records both sides of the exchange.
if user_input:
    combined_input = user_input

    # Give the model the caption of the previously analysed image, if any.
    image_context = st.session_state.get("image_caption")
    if image_context:
        combined_input = f"Image context: {image_context}. {user_input}"

    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    payload = {"inputs": combined_input}
    request_timeout_s = 60  # never let a stalled request hang the app
    max_retries = 3

    # Default reply, used when every attempt returns 503. Without it the
    # name would be unbound when the reply is stored below.
    bot_reply = "⚠️ The chatbot model is still loading. Please try again in a moment."

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                CHAT_API, headers=HEADERS, json=payload, timeout=request_timeout_s
            )
        except requests.RequestException as exc:
            bot_reply = f"⚠️ Chatbot request failed: {exc}"
            break

        if response.status_code == 503:
            # 503 is treated as "model still loading": wait and retry, but
            # do not sleep after the final attempt.
            if attempt < max_retries:
                st.warning(
                    f"⏳ Model warming up... Retrying in 5 seconds. Attempt {attempt}/{max_retries}"
                )
                time.sleep(5)
            continue

        if response.status_code != 200:
            bot_reply = f"⚠️ Chatbot Error {response.status_code}: {response.text}"
            break

        try:
            res_json = response.json()  # raises ValueError on a non-JSON body
        except ValueError:
            bot_reply = "⚠️ Error: Unable to generate response."
            break

        # The payload may be a single result dict or a list of results.
        if isinstance(res_json, list) and res_json:
            res_json = res_json[0]

        if isinstance(res_json, dict) and "generated_text" in res_json:
            bot_reply = res_json["generated_text"]
        else:
            st.error("⚠️ Unexpected response format from chatbot API.")
            bot_reply = "⚠️ Unable to generate a response."
        break

    st.session_state.messages.append({"role": "assistant", "content": bot_reply})
    with st.chat_message("assistant"):
        st.write(bot_reply)
|
|
|
|
|
# --- Clear Chat -------------------------------------------------------------
# Empties the stored conversation and immediately re-runs the script so the
# page redraws without the old messages.
if st.button("Clear Chat"):
    st.session_state.messages = []

    # st.experimental_rerun is deprecated in favour of st.rerun and is absent
    # from current Streamlit releases; prefer the new name and fall back to
    # the old one on installs that predate it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()