app.py
ADDED
@@ -0,0 +1,111 @@
+# Install required packages (uncomment when running in Colab; a bare !pip line is a
+# syntax error in a plain .py file, so in a deployed app list these in requirements.txt
+# instead; sentence-transformers is needed by HuggingFaceEmbeddings below)
+# !pip install -U gradio langchain langchain-community transformers peft accelerate bitsandbytes faiss-cpu pypdf sentence-transformers
+
+import gradio as gr
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as hf_pipeline
+from peft import PeftModel
+
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain.chains import RetrievalQA
+from langchain_community.llms import HuggingFacePipeline
+
+# Load the fine-tuned model (base + LoRA adapter) and tokenizer
+def load_llm():
+    base_model = "HuggingFaceTB/SmolLM2-360M"
+    finetuned_dir = "./smollm2-finetuned-lora1"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        device_map="cpu",  # use "auto" if a GPU is available
+        torch_dtype=torch.float32
+    )
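+    # Apply the LoRA adapter weights from the fine-tuning run on top of the base model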
+    model = PeftModel.from_pretrained(model, finetuned_dir)
+    model.eval()
+
+    tokenizer = AutoTokenizer.from_pretrained(finetuned_dir, use_fast=False)
+    tokenizer.pad_token = tokenizer.eos_token
+
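+    # Sampling-based generation: temperature/top_p control output randomness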
+    hf_pipe = hf_pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=200,
+        do_sample=True,
+        temperature=0.7,
+        top_p=0.9,
+        pad_token_id=tokenizer.eos_token_id
+    )
+
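+    # Wrap the transformers pipeline so LangChain chains can call it as an LLM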
+    return HuggingFacePipeline(pipeline=hf_pipe)
+
+# PDF → Chunks → Vectorstore → RAG chain
+def process_pdf(pdf_file):
+    # gr.File passes the upload as a filepath string in recent Gradio versions
+    # (older versions pass a tempfile-like object with a .name attribute), so
+    # resolve a path here rather than calling .read() on the value
+    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
+
+    loader = PyPDFLoader(pdf_path)
+    documents = loader.load()
+
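+    # 1000-character chunks with 200 characters of overlap preserve context across chunk boundaries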
+    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    chunks = splitter.split_documents(documents)
+
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    vectordb = FAISS.from_documents(chunks, embeddings)
+
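+    # k=4: retrieve the four chunks most similar to each query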
+    retriever = vectordb.as_retriever(search_kwargs={"k": 4})
+    chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        retriever=retriever,
+        return_source_documents=True
+    )
+
+    return chain
+
+# Global state: load the LLM once at startup; the QA chain is built when a PDF is uploaded
+llm = load_llm()
+qa_chain = None
+
+def upload_and_prepare(pdf_file):
+    global qa_chain
+    if pdf_file is None:
+        return "⚠️ Please choose a PDF file first."
+    qa_chain = process_pdf(pdf_file)
+    return "✅ PDF processed. You can now chat."
+
+def ask_question(user_query):
+    if qa_chain is None:
+        return "⚠️ Please upload and process a PDF first."
+    result = qa_chain.invoke({"query": user_query})
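+    # result["source_documents"] also holds the retrieved chunks if you want to surface citations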
+    return result["result"]
+
+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("## 📄 Chat with your PDF (SmolLM2 + LangChain + LoRA)")
+
+    with gr.Row():
+        pdf_input = gr.File(label="Upload a PDF", file_types=[".pdf"])
+        process_button = gr.Button("Process PDF")
+
+    status_output = gr.Textbox(label="Status")
+
+    with gr.Row():
+        user_input = gr.Textbox(label="Ask a question")
+        ask_button = gr.Button("Ask")
+
+    answer_output = gr.Textbox(label="Answer")
+
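+    # Wire the buttons to their callbacks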
+    process_button.click(fn=upload_and_prepare, inputs=pdf_input, outputs=status_output)
+    ask_button.click(fn=ask_question, inputs=user_input, outputs=answer_output)
+
+# Launch the app
+demo.launch()
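+# Note: when running in Colab, demo.launch(share=True) exposes a temporary public URL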