R-PRM-Demo / app.py
kevinpro's picture
Update app.py
6b84df3 verified
raw
history blame
2.41 kB
import gradio as gr
from functools import lru_cache
import os
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM
import torch
# NOTE(review): the original comment here read "assume openai_client is already
# defined, e.g.:". Nothing named openai_client appears in the visible code, so
# it looks like a leftover from another script — confirm and remove.
# Device string that the model and the prompt tensors are moved to.
device = "cuda"
# Hugging Face Hub id used for both the model weights and the tokenizer.
MODEL_NAME = "kevinpro/R-PRM-7B-DPO"
# Progress marker emitted before the model is loaded.
print("Start dowload")
def load_model():
    """Load the ``MODEL_NAME`` causal language model and place it on ``device``.

    The weights are requested in bfloat16. A confirmation line is printed
    once the model has been moved to the target device.

    Returns:
        The loaded model, already moved to ``device``.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="bfloat16",
    )
    loaded = loaded.to(device)
    print(f"Model loaded in {device}")
    return loaded
# Load the model once at import time; _translate() reads this module-level object.
model = load_model()
# Progress marker emitted after the model has been loaded.
print("Ednd dowload")
# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@lru_cache(maxsize=100)
def translate(text: str):
    """Return the model output for ``text``, reusing cached results.

    Up to 100 distinct inputs are memoised by ``lru_cache``; a repeated input
    is answered from the cache and never reaches ``_translate``.

    Args:
        text: The prompt exactly as entered by the user.

    Returns:
        Whatever ``_translate`` returned for this prompt.
    """
    result = _translate(text)
    return result
# Only assign GPU if cache not used
@spaces.GPU
def _translate(text: str):
    """Run the model on ``text`` and return the decoded sequence.

    The prompt is tokenized, moved to ``device`` and passed to
    ``model.generate`` with a budget of 2048 newly generated tokens. The first
    returned sequence is decoded with special tokens removed, stripped of
    surrounding whitespace, printed to stdout and returned.

    Args:
        text: The prompt to feed to the model.

    Returns:
        The decoded, stripped text of the generated sequence, or ``""`` when
        ``text`` tokenizes to zero tokens.
    """
    encoded = tokenizer(text, return_tensors="pt")

    # A prompt with zero tokens cannot be continued. The previous code turned
    # it into torch.tensor([[]]) — a zero-length float tensor, which is not a
    # valid input_ids — so answer with an empty result instead.
    if encoded.input_ids.shape[-1] == 0:
        return ""

    # Move the tokenizer output to the device directly instead of converting
    # the ids to a Python list and rebuilding a tensor from it.
    encoded = encoded.to(device)

    generated = model.generate(
        input_ids=encoded.input_ids,
        # Pass the tokenizer's mask explicitly rather than leaving generate()
        # to infer it; .get() keeps this working if the tokenizer emits none.
        attention_mask=encoded.get("attention_mask"),
        # Same budget as the former max_length=len(prompt) + 2048, stated
        # directly as the number of tokens to generate.
        max_new_tokens=2048,
        num_return_sequences=1,
    )

    full_output = tokenizer.decode(generated[0], skip_special_tokens=True).strip()
    print(full_output)
    return full_output
# HTML banner passed to gr.Markdown at the top of the demo page.
description = """
<div style="text-align: center;">
<h1 style="color: #0077be; font-size: 3em;">R-PRM, powered by NJUNLP</h1>
<h3 style="font-size: 3em;">🚀 We introduce Reasoning-Driven Process Reward Modeling (R-PRM), a novel approach that enhances LLMs' ability to evaluate mathematical reasoning step-by-step. By leveraging stronger LLMs to generate seed data, optimizing preferences without additional annotations, and scaling inference-time computation, R-PRM delivers comprehensive, transparent, and robust assessments of reasoning processes.</h3>
</div>
"""
# Example rows handed to gr.Examples; each row holds one value for the input textbox.
examples_inputs = [["test"]]
# Assemble the demo page: banner, input box, trigger button and output box.
with gr.Blocks() as demo:
    gr.Markdown(description)

    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)

    with gr.Row():
        btn = gr.Button("Translate text")

    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)

    # A click sends the textbox content through the cached entry point and
    # writes the result into the output box.
    btn.click(fn=translate, inputs=[input_text], outputs=output)

    # Clickable example prompts wired to the same function and components;
    # cache_examples=True asks Gradio to cache the example outputs.
    examples = gr.Examples(
        examples=examples_inputs,
        inputs=[input_text],
        fn=translate,
        outputs=output,
        cache_examples=True,
    )

print("Prepared")
demo.launch()