File size: 7,372 Bytes
4344f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cedbd37
4344f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import hashlib
import os
import time

import requests
import streamlit as st

# Load Hugging Face API key (None when the HF_API_KEY environment variable is unset)
HF_API_KEY = os.getenv("HF_API_KEY")

# Define API URLs
IMG2TEXT_API = "https://api-inference.huggingface.co/models/nlpconnect/vit-gpt2-image-captioning"
CHAT_API = "https://api-inference.huggingface.co/models/facebook/blenderbot-3B"

# Auth headers for the JSON chat requests. A bearer token is only sent when a key is
# actually configured: interpolating a missing key would transmit the literal token
# "None" (or an empty token), which is never a valid credential.
HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}

st.title("Multimodal Chatbot")

# A brand-new session starts with a single greeting from the bot; later reruns
# find the key already present and keep the stored conversation.
if "messages" not in st.session_state:
    welcome_text = "Hello! I'm a chatbot. You can upload an image or ask me anything to get started!"
    st.session_state.messages = [{"role": "assistant", "content": welcome_text}]

# Re-render every stored turn: Streamlit re-executes this script on each interaction,
# so the transcript has to be redrawn from session state every time.
for entry in st.session_state.messages:
    speaker, text = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.write(text)

# Input widgets: an optional picture to talk about, plus the chat prompt box.
uploaded_file = st.file_uploader(label="Upload an image...", type=["jpg", "png", "jpeg"])
user_input = st.chat_input(placeholder="Ask about this image or anything...")

# Remains None unless the image step below produces a caption during this run.
image_caption = None

# Process image if uploaded
def _caption_image(img_bytes, max_retries=3):
    """Return a caption for ``img_bytes`` from the image-to-text model, or None.

    Retries while the model is warming up (HTTP 503). Every failure is reported in
    the UI through ``st.error`` and yields None, so the caller only handles success.
    """
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/octet-stream",
    }
    for attempt in range(1, max_retries + 1):
        try:
            # timeout: without one, a stalled connection blocks this script run forever.
            response = requests.post(IMG2TEXT_API, headers=headers, data=img_bytes, timeout=60)
        except requests.RequestException as e:
            st.error(f"⚠️ Image API request failed: {e}")
            return None

        if response.status_code == 503:
            if attempt == max_retries:
                st.error(
                    "⚠️ Image captioning model is still loading. "
                    "Remove and re-upload the image to try again."
                )
                return None
            st.warning(f"⏳ Model warming up... Retrying in 5 seconds. Attempt {attempt}/{max_retries}")
            time.sleep(5)  # Wait before retrying
            continue

        if response.status_code != 200:
            st.error(f"⚠️ Image API Error: {response.status_code} - {response.text}")
            return None

        try:
            res_json = response.json()
        except ValueError as e:  # body was not valid JSON
            st.error(f"⚠️ Error: Unable to generate caption. Details: {e}")
            return None

        # Accept either [{"generated_text": ...}] or {"generated_text": ...}; anything
        # else (including a missing/empty caption) is treated as a format error rather
        # than forwarding a placeholder to the chatbot as if it described the image.
        caption = None
        if isinstance(res_json, list) and res_json and isinstance(res_json[0], dict):
            caption = res_json[0].get("generated_text")
        elif isinstance(res_json, dict):
            caption = res_json.get("generated_text")
        if not caption:
            st.error("⚠️ Unexpected response format from image captioning API.")
            return None
        return caption
    return None  # only reachable when max_retries < 1


def _reply_to_caption(caption):
    """Ask the chat model to respond to an image described by ``caption``.

    Returns a generic prompt asking what the user wants to know about the image
    instead of raising on request errors, non-200 statuses, or malformed JSON.
    """
    fallback = "I received your image. What would you like to ask about it?"
    bot_context = (
        f"Consider this image: {caption}. Please provide a relevant and engaging response to the image."
    )
    try:
        response = requests.post(CHAT_API, headers=HEADERS, json={"inputs": bot_context}, timeout=60)
        if response.status_code != 200:
            return fallback
        res_json = response.json()
    except (requests.RequestException, ValueError):
        return fallback
    if isinstance(res_json, list) and res_json and isinstance(res_json[0], dict):
        return res_json[0].get("generated_text", fallback)
    if isinstance(res_json, dict) and "generated_text" in res_json:
        return res_json["generated_text"]
    return fallback


# Streamlit re-runs this whole script on every interaction (each chat message included)
# and the uploader keeps returning the same file until it is removed, so remember a
# digest of the last image handled and announce/caption each image only once.
if uploaded_file is None:
    # Uploader emptied: forget the digest so re-uploading the same image processes it again.
    st.session_state.pop("processed_image_digest", None)
elif uploaded_file.type not in ["image/jpeg", "image/png"]:
    st.error("⚠️ Please upload a valid JPG or PNG image.")
else:
    img_bytes = uploaded_file.getvalue()  # whole payload, independent of read position
    img_digest = hashlib.sha256(img_bytes).hexdigest()
    if st.session_state.get("processed_image_digest") != img_digest:
        # Mark as handled before calling the API so a failure is not retried (and the
        # "[Image Uploaded]" entry not duplicated) on every subsequent rerun.
        st.session_state.processed_image_digest = img_digest
        st.session_state.messages.append({"role": "user", "content": "[Image Uploaded]"})
        user_bubble = st.chat_message("user")
        user_bubble.image(img_bytes, caption="Uploaded Image", use_column_width=True)

        image_caption = _caption_image(img_bytes)
        if image_caption:
            # Later chat turns read this to give the text-only chat model image context.
            st.session_state.image_caption = image_caption
            user_bubble.write(f"**Image to text context generated:** {image_caption}")

            bot_reply = _reply_to_caption(image_caption)
            st.session_state.messages.append({"role": "assistant", "content": bot_reply})
            with st.chat_message("assistant"):
                st.write(bot_reply)

# Process user input if provided
if user_input:
    # Prefix the question with the latest image caption (if any) so the text-only
    # chat model has some idea of what "this image" refers to.
    image_context = st.session_state.get("image_caption")
    combined_input = f"Image context: {image_context}. {user_input}" if image_context else user_input

    # Append user message
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.write(user_input)

    # Send combined input to chatbot, retrying while the model is warming up (HTTP 503).
    payload = {"inputs": combined_input}
    max_retries = 3
    # Pre-set so bot_reply is always bound below: if every attempt returns 503 the
    # loop ends without assigning a reply.
    bot_reply = "⚠️ The chatbot model is still loading. Please try again in a moment."
    for i in range(max_retries):
        try:
            # timeout: without one, a stalled connection blocks this script run forever.
            response = requests.post(CHAT_API, headers=HEADERS, json=payload, timeout=60)
        except requests.RequestException as e:
            bot_reply = f"⚠️ Chatbot request failed: {e}"
            break

        if response.status_code == 503:
            # Only announce and wait when another attempt will actually follow.
            if i + 1 < max_retries:
                st.warning(f"⏳ Model warming up... Retrying in 5 seconds. Attempt {i+1}/{max_retries}")
                time.sleep(5)
            continue

        if response.status_code != 200:
            bot_reply = f"⚠️ Chatbot Error {response.status_code}: {response.text}"
            break

        try:
            res_json = response.json()
        except ValueError:  # body was not valid JSON
            bot_reply = "⚠️ Error: Unable to generate response."
            break

        # The reply may arrive as a dict or as a list of dicts carrying 'generated_text'.
        if isinstance(res_json, dict) and "generated_text" in res_json:
            bot_reply = res_json["generated_text"]
        elif (
            isinstance(res_json, list)
            and res_json
            and isinstance(res_json[0], dict)
            and "generated_text" in res_json[0]
        ):
            bot_reply = res_json[0]["generated_text"]
        else:
            st.error("⚠️ Unexpected response format from chatbot API.")
            bot_reply = "⚠️ Unable to generate a response."
        break

    # Append bot response
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})
    with st.chat_message("assistant"):
        st.write(bot_reply)

# Clear button to reset chat
if st.button("Clear Chat"):
    # Drop the history key entirely (rather than emptying the list) so the
    # session-state initialisation near the top of the script re-seeds the
    # welcome message on the rerun below.
    st.session_state.pop("messages", None)
    # st.experimental_rerun was deprecated in favour of st.rerun and later removed;
    # prefer st.rerun and fall back only on releases that predate it.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()