# NOTE(review): the lines below replaced scraping residue from the Hugging Face
# Spaces file-viewer page (status text, commit hashes, line-number gutter) that
# was pasted in with the code and would break the Python file.
# Source: a Hugging Face Space app file (~1,988 bytes).
import streamlit as st
import os
from dotenv import load_dotenv
from transformers import (
pipeline,
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM
)
# Load environment variables from .env into os.environ, if a .env file exists
# (e.g. HF tokens / API keys read elsewhere via os.getenv).
load_dotenv()
# Streamlit page chrome: browser-tab title and on-page heading.
st.set_page_config(page_title="Educational Chatbot")
st.title("🎓 Educational Chatbot")
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the DeepSeek-R1 tokenizer and model and return a text-generation pipeline.

    Cached with ``st.cache_resource`` so the (very large) model is loaded at most
    once per Streamlit server process, not on every script rerun.

    Returns:
        transformers.Pipeline: a ready-to-call "text-generation" pipeline.
    """
    model_id = "deepseek-ai/DeepSeek-R1"

    # 1. Load the remote config. NOTE: trust_remote_code executes Python shipped
    #    with the model repo — acceptable only because the source is trusted.
    config = AutoConfig.from_pretrained(
        model_id,
        trust_remote_code=True,
    )

    # 2. Remove the unsupported fp8 quantization config entirely. Assigning None
    #    would leave the attribute present, which can still trip hasattr()-based
    #    checks inside transformers — delete it instead.
    if hasattr(config, "quantization_config"):
        delattr(config, "quantization_config")

    # 3. Load tokenizer and model with the patched config.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    # device_map must be given here: pipeline() ignores device_map when it
    # receives an already-instantiated model object (it only applies it when
    # loading from a model id string).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        config=config,
        device_map="auto",  # remove for CPU-only deployments
    )

    # 4. Build the text-generation pipeline around the loaded objects.
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        trust_remote_code=True,
    )
    return gen
# Load the model once (cached by @st.cache_resource inside load_model, so this
# is cheap on every Streamlit rerun after the first).
generator = load_model()
# Initialize chat history: a list of (sender, message) tuples that survives
# reruns via st.session_state.
if "history" not in st.session_state:
    st.session_state.history = []
# User input box (value persists across Streamlit reruns).
user_input = st.text_input("Ask me anything:")

# Streamlit reruns this whole script on every widget interaction while the
# text_input keeps its value, so without a guard the same question would be
# re-generated and re-appended to history on each rerun. Track the last
# processed input in session_state and only handle genuinely new questions.
if user_input and st.session_state.get("last_input") != user_input:
    st.session_state.last_input = user_input
    # Record the user's turn up front so it is kept even if generation fails.
    st.session_state.history.append(("You", user_input))
    try:
        outputs = generator(user_input, return_full_text=False)
        reply = outputs[0]["generated_text"].strip()
    except Exception as e:
        # Best-effort: surface the failure as the bot's reply instead of crashing.
        reply = f"⚠️ Error: {e}"
    st.session_state.history.append(("Bot", reply))
# Display chat history in chronological order. The previous reversed() loop
# flattened the (You, Bot) pairs, so every answer rendered ABOVE its own
# question; chronological order keeps each exchange readable top-to-bottom.
for sender, msg in st.session_state.history:
    st.markdown(f"**{sender}:** {msg}")
|