# (file-viewer page header, not code) AnshulS's picture
# Update app.py
# 59a4ae8 verified
import ast
import traceback

import gradio as gr
import pandas as pd
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from reranker import rerank
from retriever import get_relevant_passages
# Create the FastAPI application object that the routes below attach to.
app = FastAPI(root_path="")

# Register CORS handling that accepts any origin, method, and header, with
# credentialed requests enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS guide -- confirm this is intentional.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# === Load and Clean CSV ===
def _parse_test_type(value):
    """Turn a raw ``test_type`` cell into a Python object without executing it.

    Strings are parsed with ``ast.literal_eval`` (literals only, no code
    execution).  A string that is not a valid literal is kept as a
    one-element list so one malformed row cannot abort the whole load.
    Non-string values are returned unchanged.
    """
    if not isinstance(value, str):
        return value
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        return [value]


def clean_df(df):
    """Project the raw assessment table onto the columns the app uses.

    Columns are read by position: 2 = link, 3 = remote flag, 4 = adaptive
    flag, 5 = test types, 6 = description, 9 = duration text.

    Args:
        df: Raw DataFrame with at least ten columns.  It is copied, not
            modified.

    Returns:
        A new DataFrame with the columns ``url``, ``adaptive_support``,
        ``remote_support``, ``description``, ``duration`` and ``test_type``.
    """
    df = df.copy()

    def yes_no(flag):
        # Only the exact string "T" counts as a positive flag.
        return "Yes" if flag == "T" else "No"

    link_col = df.iloc[:, 2].astype(str)
    if link_col.str.contains('http|www').any():
        # At least one value already looks absolute: keep the column as-is.
        df["url"] = link_col
    else:
        # Every value is site-relative: guarantee a leading slash, then
        # prefix the host.
        with_slash = link_col.where(link_col.str.startswith('/'), '/' + link_col)
        df["url"] = "https://www.shl.com" + with_slash

    df["remote_support"] = df.iloc[:, 3].map(yes_no)
    df["adaptive_support"] = df.iloc[:, 4].map(yes_no)
    df["test_type"] = df.iloc[:, 5].apply(_parse_test_type)
    df["description"] = df.iloc[:, 6]

    # Keep the first run of digits in the duration text; anything without
    # digits becomes NaN.
    duration_digits = df.iloc[:, 9].astype(str).str.extract(r'(\d+)')[0]
    df["duration"] = pd.to_numeric(duration_digits, errors='coerce')

    return df[["url", "adaptive_support", "remote_support", "description", "duration", "test_type"]]
# Load the assessment catalogue once at import time.  On any failure the
# error is printed and an empty frame with the expected columns is used, so
# the application can still start.
try:
    df = pd.read_csv("assesments.csv", encoding='utf-8')
    df_clean = clean_df(df)
except Exception as e:
    print(f"Error loading data: {e}")
    df_clean = pd.DataFrame(
        columns=[
            "url",
            "adaptive_support",
            "remote_support",
            "description",
            "duration",
            "test_type",
        ]
    )
else:
    print(f"Successfully loaded {len(df_clean)} assessments")
# === Utility ===
def validate_and_fix_urls(candidates):
    """Ensure every candidate dict carries an absolute ``url`` string.

    Each dict in *candidates* is updated in place; non-dict items are
    skipped.  Rules, in order:

    * missing, falsy, NaN, or whitespace-only url ->
      ``https://www.shl.com/missing-url``
    * already absolute (``http://`` / ``https://``, any letter case) ->
      kept, with surrounding whitespace trimmed
    * bare host such as ``www.shl.com/x`` -> ``https://`` is prepended
    * anything else is treated as a path on ``https://www.shl.com``

    Args:
        candidates: Iterable of candidate records (normally dicts).

    Returns:
        The same *candidates* object that was passed in.
    """
    base = "https://www.shl.com"
    missing = f"{base}/missing-url"

    for candidate in candidates:
        if not isinstance(candidate, dict):
            continue

        raw = candidate.get('url')
        # NaN is truthy, so it needs an explicit check (NaN != NaN).
        is_nan = isinstance(raw, float) and raw != raw
        url = "" if (not raw or is_nan) else str(raw).strip()
        if not url:
            candidate['url'] = missing
            continue

        lowered = url.lower()
        if lowered.startswith(('http://', 'https://')):
            candidate['url'] = url
        elif lowered.startswith('www.'):
            # A host without a scheme: add the scheme instead of nesting it
            # under the default host.
            candidate['url'] = f"https://{url}"
        elif url.startswith('/'):
            candidate['url'] = f"{base}{url}"
        else:
            candidate['url'] = f"{base}/{url}"

    return candidates
# === Recommendation Logic ===
def _as_type_list(value):
    """Coerce a ``test_type`` cell to a list without evaluating arbitrary code."""
    if isinstance(value, list):
        return value
    if isinstance(value, str) and value.startswith('['):
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            pass  # Not a valid literal: fall back to wrapping the text.
    return [str(value)]


def _duration_or_none(value):
    """Return *value* as an int, or None when it is missing or not numeric."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # None -> TypeError, NaN / non-numeric text -> ValueError,
        # infinity -> OverflowError.
        return None


def recommend(query):
    """Recommend assessments for a job description.

    Retrieves up to 20 candidate rows for *query* from ``df_clean``,
    normalises them, and passes them to ``rerank``.

    Args:
        query: Free-text job description.

    Returns:
        The value returned by ``rerank(query, candidates)``; when it contains
        a ``recommended_assessments`` entry, those records have their URLs
        normalised.  On empty input, no matches, or any exception, a dict of
        the form ``{"error": <message>}`` is returned instead.
    """
    if not query or not query.strip():
        return {"error": "Please enter a job description"}
    try:
        top_k_df = get_relevant_passages(query, df_clean, top_k=20)
        if top_k_df.empty:
            return {"error": "No matching assessments found"}

        # Copy before assigning columns, in case the retriever handed back a
        # view of the shared catalogue frame.
        top_k_df = top_k_df.copy()
        top_k_df['test_type'] = top_k_df['test_type'].apply(_as_type_list)

        candidates = top_k_df.to_dict(orient="records")
        # Normalise durations on the plain records rather than in the frame:
        # an integer column cannot hold None, so pandas would turn it back
        # into float/NaN.
        for candidate in candidates:
            candidate['duration'] = _duration_or_none(candidate.get('duration'))

        candidates = validate_and_fix_urls(candidates)
        result = rerank(query, candidates)
        if 'recommended_assessments' in result:
            result['recommended_assessments'] = validate_and_fix_urls(result['recommended_assessments'])
        return result
    except Exception as e:
        print(traceback.format_exc())
        return {"error": f"Error processing request: {str(e)}"}
# === FastAPI Endpoints ===
@app.get("/health")
async def health():
    """Health probe: always answers HTTP 200 with ``{"status": "healthy"}``."""
    payload = {"status": "healthy"}
    return JSONResponse(status_code=200, content=payload)
@app.post("/recommend")
async def recommend_api(request: Request):
    """JSON endpoint wrapping ``recommend``.

    Expects a JSON object body with a string ``query`` field.

    Responses:
        200 -- the dict returned by ``recommend(query)``.
        400 -- the body is not valid JSON, is not a JSON object, or
               ``query`` is missing, blank, or not a string.
        500 -- any unexpected failure; the body carries the exception text.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and decoding errors are ValueError
            # subclasses; a malformed body is the client's fault, not ours.
            return JSONResponse(content={"error": "Request body must be valid JSON"}, status_code=400)

        if not isinstance(data, dict):
            return JSONResponse(content={"error": "Request body must be a JSON object"}, status_code=400)

        query = data.get("query", "")
        if not isinstance(query, str):
            return JSONResponse(content={"error": "Field 'query' must be a string"}, status_code=400)

        query = query.strip()
        if not query:
            return JSONResponse(content={"error": "Missing query"}, status_code=400)

        result = recommend(query)
        return JSONResponse(content=result, status_code=200)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Build the Gradio UI: one multi-line text box in, a JSON view out, with
# recommend() as the handler.
_job_description_box = gr.Textbox(
    label="Enter Job Description",
    lines=4,
    placeholder="Paste a job description here...",
)
_results_view = gr.JSON(label="Recommended Assessments")

demo = gr.Interface(
    fn=recommend,
    inputs=_job_description_box,
    outputs=_results_view,
    title="SHL Assessment Recommender",
    description="Paste a job description to get the most relevant SHL assessments.",
    analytics_enabled=False,
)
# This is the pattern for Gradio 5.x
#app = gr.mount_gradio_app(app, demo, path="/")
# Mount the Gradio interface at the root path of the FastAPI app and rebind
# `app` to whatever mount_gradio_app returns.
# NOTE(review): `ssl_verify` and `show_error` are passed inside app_kwargs;
# confirm against the Gradio docs that app_kwargs is the right place for
# them and that they have the intended effect.
_gradio_app_kwargs = {
    "ssl_verify": False,
    "show_error": True,
}
app = gr.mount_gradio_app(app, demo, path="/", app_kwargs=_gradio_app_kwargs)
# Entry point: serve the combined app with uvicorn when run as a script.
if __name__ == "__main__":
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "proxy_headers": True,       # Process forwarded headers
        "forwarded_allow_ips": "*",  # Trust forwarded headers from any IP
    }
    uvicorn.run(app, **server_options)
#uvicorn.run(app, host="0.0.0.0", port=7860)
#app