import math
import warnings
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from transformers import TextIteratorStreamer

from model import Rank1
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")
# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
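# Populated in the __main__ block before the demo launches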
MODEL_PATH = None
reranker = None
@spaces.GPU
def process_input(query: str, passage: str) -> Iterator[tuple[str, str, str]]:
    """Run the reranker on a query/passage pair, streaming intermediate
    reasoning and yielding (relevance, confidence, reasoning) updates."""
prompt = f"Determine if the following passage is relevant to the query. Answer only with 'true' or 'false'.\n" \
f"Query: {query}\n" \
f"Passage: {passage}\n" \
"<think>"
reranker.model = reranker.model.to("cuda")
inputs = reranker.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=reranker.context_size
).to("cuda")
streamer = TextIteratorStreamer(
reranker.tokenizer,
skip_prompt=True,
skip_special_tokens=False
)
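    # The prompt ends with "<think>", so seed the visible transcript with it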
current_text = "<think>"
generation_output = None
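    # Run generation on a background thread so this thread can consume the
    # streamer and yield partial results to the UI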
def generate_with_output():
nonlocal generation_output
generation_output = reranker.model.generate(
**inputs,
generation_config=reranker.generation_config,
stopping_criteria=reranker.stopping_criteria,
return_dict_in_generate=True,
output_scores=True,
streamer=streamer
)
thread = Thread(target=generate_with_output)
thread.start()
# Stream tokens as they're generated
for new_text in streamer:
current_text += new_text
yield (
"Processing...",
"Processing...",
current_text
)
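    # Wait for generation to finish so generation_output is populated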
thread.join()
    # If the stopping criteria fired before "</think>" was emitted, append the
    # matched sequence so the displayed reasoning is complete
    if "</think>" not in current_text:
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
    with torch.no_grad():
        # Logits at the final generated position; softmax over the 'true' and
        # 'false' token logits, shifted by the max logit to avoid overflow
        final_scores = generation_output.scores[-1][0]
        true_logit = final_scores[reranker.true_token].item()
        false_logit = final_scores[reranker.false_token].item()
        max_logit = max(true_logit, false_logit)
        true_score = math.exp(true_logit - max_logit)
        false_score = math.exp(false_logit - max_logit)
        score = true_score / (true_score + false_score)
    yield (
        "true" if score > 0.5 else "false",
        f"{score:.4f}",
        current_text
    )
# Example inputs: one relevant and one irrelevant query/passage pair
examples = [
[
"What movies were directed by James Cameron?",
"Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
],
[
"What movies were directed by James Cameron?",
"Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
]
]
theme = gr.themes.Soft(
primary_hue="indigo",
font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
neutral_hue="slate",
radius_size="lg",
)
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: For demo purposes, this is a <span style="color: red;">quantized</span> model limited to a <span style="color: red;">1024</span>-token context length. HF Spaces cannot use vLLM, so this demo is <span style="color: red;">significantly slower</span> than the full model.')
    gr.HTML('Paper link: <a href="https://arxiv.org/abs/2502.18418" target="_blank">https://arxiv.org/abs/2502.18418</a>')
with gr.Row():
with gr.Column():
query_input = gr.Textbox(
label="Query",
placeholder="Enter your search query here",
lines=2
)
passage_input = gr.Textbox(
label="Passage",
placeholder="Enter the passage to check for relevance",
lines=6
)
submit_button = gr.Button("Check Relevance")
with gr.Column():
relevance_output = gr.Textbox(label="Relevance")
confidence_output = gr.Textbox(label="Confidence")
reasoning_output = gr.Textbox(
label="Model Reasoning",
lines=10,
interactive=False
)
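    # Cache example outputs at startup so clicking an example returns instantly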
gr.Examples(
examples=examples,
inputs=[query_input, passage_input],
outputs=[relevance_output, confidence_output, reasoning_output],
fn=process_input,
cache_examples=True,
)
submit_button.click(
fn=process_input,
inputs=[query_input, passage_input],
outputs=[relevance_output, confidence_output, reasoning_output],
api_name="predict",
queue=True
)
if __name__ == "__main__":
    # Download the model weights up front so the first request doesn't block on the download
MODEL_PATH = snapshot_download(
repo_id="orionweller/rank1-7b-awq",
)
print(f"Downloaded model to: {MODEL_PATH}")
reranker = Rank1(model_name_or_path=MODEL_PATH)
    demo.launch(share=False)