"""Research-paper analysis router: keyword extraction, Semantic Scholar lookup,
SPECTER2 embeddings, and MongoDB persistence."""

from fastapi import FastAPI, HTTPException, APIRouter, Request
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
import requests
import spacy
import time
import torch
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
from typing import Optional

router = APIRouter()

# en_core_sci_sm is the small scispacy biomedical pipeline; it is typically
# installed from the scispacy release wheels rather than via `spacy download`.
nlp = spacy.load("en_core_sci_sm")


class ResearchQuery(BaseModel):
    userId: str
    topic: str
    year: str


class PaperMetadata(BaseModel):
    paperId: str
    title: str
    abstract: str
    citationCount: int
    influentialCitationCount: int
    publicationDate: str
    url: str
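
# Note: PaperMetadata mirrors the fields requested from the Semantic Scholar
# batch endpoint below; it is not currently used to validate responses.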


def extract_keywords(text):
    """Extract candidate keywords from free text using the scispacy pipeline.

    Noun chunks are kept as multi-word keywords, standalone nouns/verbs are added,
    and any keyword contained in a longer noun chunk is then dropped.
    """
    doc = nlp(text.lower())

    # Multi-word candidates from noun chunks.
    noun_chunks = [chunk.text for chunk in doc.noun_chunks]

    # Single-word candidates: content-bearing nouns and verbs.
    individual_tokens = [
        token.text
        for token in doc
        if token.pos_ in ["NOUN", "VERB"] and not token.is_stop
    ]

    keywords = set(noun_chunks + individual_tokens)

    # Drop keywords that are substrings of a longer noun chunk to avoid duplicates.
    cleaned_keywords = set()
    for keyword in keywords:
        if not any(keyword in chunk for chunk in noun_chunks if keyword != chunk):
            cleaned_keywords.add(keyword)

    return sorted(cleaned_keywords)
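
# Illustrative example (exact output depends on the en_core_sci_sm version):
#   extract_keywords("Transformer models for protein structure prediction")
#   might return ["protein structure prediction", "transformer models"].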


def construct_query(keywords):
    # Join keywords with " + ", intended as an AND-style query for the
    # Semantic Scholar bulk search endpoint.
    query = " + ".join(keywords)
    return query
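
# Example: construct_query(["protein structure prediction", "transformer models"])
#          -> "protein structure prediction + transformer models"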


def fetch_paper_ids(query, year):
    """Search the Semantic Scholar bulk endpoint and return matching paper IDs."""
    search_url = "https://api.semanticscholar.org/graph/v1/paper/search/bulk"
    search_params = {
        "query": query,
        "year": year,
        "fields": "paperId",
    }
    response = requests.get(search_url, params=search_params)

    if response.status_code == 200:
        data = response.json()
        papers = data.get("data", [])
        paper_ids = [paper["paperId"] for paper in papers]
        return paper_ids
    else:
        raise HTTPException(status_code=response.status_code, detail="Error fetching paper IDs")


def fetch_metadata(batch_ids):
    """Fetch metadata for a batch of paper IDs, retrying on 429 rate-limit responses."""
    graph_url = "https://api.semanticscholar.org/graph/v1/paper/batch"
    metadata_params = {
        "fields": "title,abstract,citationCount,influentialCitationCount,publicationDate,url"
    }

    attempt = 0
    max_retries = 20

    while attempt < max_retries:
        response = requests.post(graph_url, json={"ids": batch_ids}, params=metadata_params)

        if response.status_code == 200:
            return response.json()

        elif response.status_code == 429:
            # Rate limited: back off briefly and retry.
            wait_time = 5
            print(f"429 Too Many Requests. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
            attempt += 1

        else:
            raise HTTPException(status_code=response.status_code, detail="Error fetching metadata")

    raise HTTPException(status_code=500, detail="Max retries reached while fetching metadata")
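
# Assumption: the batch endpoint caps the number of IDs per request (around 500
# per the Semantic Scholar docs at the time of writing), so the caller below keeps
# batches well under that limit (100 IDs per call).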


def clean_and_process_metadata(metadata_list):
    """Drop papers without an abstract and attach a SPECTER2 embedding to the rest.

    Note: the tokenizer and model are loaded on every call; if this becomes a
    bottleneck they could be cached at module level instead.
    """
    tokenizer = AutoTokenizer.from_pretrained("allenai/specter2_base")
    model = AutoAdapterModel.from_pretrained("allenai/specter2_base")

    # SPECTER2 proximity adapter for document-level similarity embeddings.
    model.load_adapter("allenai/specter2", source="hf", load_as="proximity", set_active=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    cleaned_metadata = []
    invalid_papers = []

    for paper in metadata_list:
        paper_id = paper.get("paperId")
        title = paper.get("title")
        abstract = paper.get("abstract")

        # Papers without an abstract cannot be embedded meaningfully; a missing
        # title is tolerated (handled when building `text` below).
        if not abstract:
            invalid_papers.append(paper_id)
            continue

        text = title + tokenizer.sep_token + abstract if title and abstract else abstract

        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs)

        # CLS-token embedding; the leading batch dimension is kept, so `vector`
        # is a list containing a single list of floats.
        embedding = output.last_hidden_state[:, 0, :].cpu().numpy().tolist()

        paper["embedding"] = {"model": "specter2_proximity", "vector": embedding}
        cleaned_metadata.append(paper)

    return cleaned_metadata, invalid_papers
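
# Illustrative shape of the stored embedding (values made up; the specter2_base
# hidden size is typically 768):
#   {"model": "specter2_proximity", "vector": [[0.12, -0.45, ...]]}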


async def save_to_mongodb(userId: str, topic: str, year: str, metadata_list, request: Request):
    """Upsert the processed papers for a (userId, topic, year) triple."""
    collection = request.app.state.collection

    await collection.update_one(
        {"userId": userId, "topic": topic, "year": year},
        {"$set": {"papers": metadata_list}},
        upsert=True,
    )

    print(f"Saved {len(metadata_list)} papers for user '{userId}' and topic '{topic}'")


@router.get("/test")
def greet():
    # Simple liveness check for the router.
    return {
        "message": "hello jessi"
    }


@router.post("/analyze")
async def analyze(query: ResearchQuery, request: Request):
    """End-to-end pipeline: keywords -> paper IDs -> metadata -> embeddings -> MongoDB."""
    keywords = extract_keywords(query.topic)
    refined_query = construct_query(keywords)

    print(f"\n🔍 Refined Query: {refined_query}")

    paper_ids = fetch_paper_ids(refined_query, query.year)
    print(f"\nFound {len(paper_ids)} papers for '{query.topic}' in {query.year}\n")

    # Fetch metadata in batches to stay under the batch endpoint's per-request ID limit.
    metadata_list = []
    batch_size = 100

    for i in range(0, len(paper_ids), batch_size):
        batch_ids = paper_ids[i : i + batch_size]
        metadata = fetch_metadata(batch_ids)
        metadata_list.extend(metadata)
        print(f"Retrieved {len(batch_ids)} papers' metadata...")

    cleaned_metadata, invalid_papers = clean_and_process_metadata(metadata_list)

    await save_to_mongodb(query.userId, query.topic, query.year, cleaned_metadata, request)

    return {
        "message": f"Processed {len(cleaned_metadata)} papers and cleaned data.",
        "invalid_papers_removed": len(invalid_papers),
    }
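
# Illustrative request (values made up):
#   POST /analyze
#   {"userId": "u123", "topic": "graph neural networks for drug discovery", "year": "2023"}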


class CheckDataRequest(BaseModel):
    userId: str
    topic: str
    year: Optional[str] = None


@router.post("/check-data-exists/")
async def check_data_exists(request_data: CheckDataRequest, request: Request):
    """Report whether papers for the given user/topic (and optional year) are already stored."""
    query = {
        "userId": request_data.userId,
        "topic": request_data.topic
    }

    # Year is optional; only filter on it when provided.
    if request_data.year:
        query["year"] = request_data.year

    collection = request.app.state.collection
    document = await collection.find_one(query)

    return {
        "exists": document is not None,
        "message": "Data found" if document else "Data not found"
    }
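
# Example response when a matching document exists:
#   {"exists": true, "message": "Data found"}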