limyehji committed on
Commit
e2d169f
·
verified ·
1 Parent(s): 9cd80c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -6,11 +6,29 @@ from datasets import load_dataset
6
 
7
  app = FastAPI()
8
 
9
- # 1. 임베딩 모델 로드
10
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
11
 
12
- # 2. Hugging Face์—์„œ MedRAG ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
13
  dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)
14
 
15
  # 3. 데이터 변환
16
  texts = [entry["content"] for entry in dataset] # "content" 필드 활용
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
+ # 1. MiniLM 모델 로드 (데이터 벡터화)
10
  embed_model = SentenceTransformer("all-MiniLM-L6-v2")
11
 
12
+ # 2. MedRAG 데이터셋 로드
13
  dataset = load_dataset("MedRAG/textbooks", split="train", streaming=True)
14
 
15
  # 3. 데이터 변환
16
  texts = [entry["content"] for entry in dataset] # "content" 필드 활용
17
+
18
+ # 4. ๋ฒกํ„ฐ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ ๋ฐ FAISS์— ์ €์žฅ
19
+ vectors = embed_model.encode(texts)
20
+ dimension = vectors.shape[1] # 임베딩 차원
21
+ index = faiss.IndexFlatL2(dimension) # L2 거리 기반 FAISS 인덱스 생성
22
+ index.add(np.array(vectors)) # FAISS์— ๋ฒกํ„ฐ ์ถ”๊ฐ€
23
+
24
+ # 5. ๊ฒ€์ƒ‰ API (GPTs์—์„œ ํ˜ธ์ถœ ๊ฐ€๋Šฅ)
25
+ @app.get("/search")
26
+ def search(query: str):
27
+ """ ์‚ฌ์šฉ์ž์˜ ์ฟผ๋ฆฌ๋ฅผ ๋ฒกํ„ฐ ๋ณ€ํ™˜ ํ›„, FAISS์—์„œ ๊ฒ€์ƒ‰ํ•˜์—ฌ ๊ด€๋ จ ๋ฌธ์„œ ๋ฐ˜ํ™˜ """
28
+ query_vector = embed_model.encode([query])
29
+ query_vector = np.array(query_vector, dtype=np.float32) # FAISS 호환
30
+
31
+ _, I = index.search(query_vector, k=3) # FAISS๋กœ Top-3 ๊ฒ€์ƒ‰
32
+ results = [texts[i] for i in I[0]] # 검색된 문서 반환
33
+
34
+ return {"retrieved_docs": results}