"""Gradio demo for R-PRM (Reasoning-Driven Process Reward Modeling).

Loads the ``kevinpro/R-PRM-7B-DPO`` causal LM once at start-up and exposes a
single text-in / text-out interface: the text typed by the user is passed to
the model as-is and the decoded generation is shown in the output box.
"""

import os  # noqa: F401  (kept: may be relied on elsewhere / by the Space runtime)
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer  # noqa: F401

# Device the model weights and the input tensors are moved to.
device = "cuda"
MODEL_NAME = "kevinpro/R-PRM-7B-DPO"
# Maximum number of tokens generated per request, on top of the prompt tokens.
MAX_NEW_TOKENS = 2048

print("Start download")


def load_model():
    """Load ``MODEL_NAME`` as a causal LM in bfloat16 and move it to ``device``.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype="bfloat16"
    ).to(device)
    print(f"Model loaded in {device}")
    return model


model = load_model()
print("End download")

# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


@lru_cache(maxsize=100)
def translate(text: str) -> str:
    """Return the model's decoded output for *text*, memoizing recent results.

    The name is kept because the UI below is wired to it; no translation is
    actually performed -- the text is fed to the model unchanged.

    Args:
        text: Raw prompt from the input textbox.

    Returns:
        The decoded generation (see ``_translate``), or ``""`` when *text* is
        empty or whitespace-only.
    """
    # Empty input may tokenize to zero ids (tokenizer-dependent), leaving
    # generate() with no prompt; answer immediately instead of requesting a GPU.
    if not text or not text.strip():
        return ""
    # Only assign GPU if cache not used
    return _translate(text)


@spaces.GPU
def _translate(text: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Run generation for *text* on the GPU and decode the result.

    Args:
        text: Non-empty prompt.
        max_new_tokens: Maximum number of tokens to generate beyond the prompt.

    Returns:
        The first returned sequence decoded with special tokens skipped and
        surrounding whitespace stripped. For a causal LM, ``generate`` returns
        the prompt followed by the new tokens, so the prompt text is included.
    """
    # Keep input_ids and attention_mask together and move them straight to the
    # model's device; passing the mask explicitly avoids generate()'s
    # missing-attention-mask warning.
    encoded = tokenizer(text, return_tensors="pt").to(device)
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        # max_new_tokens (rather than max_length = prompt + N) states the limit
        # directly and is not overridden by a max_new_tokens set in the
        # model's generation config.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    full_output = tokenizer.decode(generated[0], skip_special_tokens=True).strip()
    print(full_output)
    return full_output


description = """

R-PRM, powered by NJUNLP

🚀 We introduce Reasoning-Driven Process Reward Modeling (R-PRM), a novel approach that enhances LLMs' ability to evaluate mathematical reasoning step-by-step. By leveraging stronger LLMs to generate seed data, optimizing preferences without additional annotations, and scaling inference-time computation, R-PRM delivers comprehensive, transparent, and robust assessments of reasoning processes.

"""

examples_inputs = [["test"]]

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)
    with gr.Row():
        btn = gr.Button("Translate text")
    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)
    btn.click(
        translate,
        inputs=[input_text],
        outputs=output,
    )
    # cache_examples=True runs `translate` on each example ahead of time so
    # clicking an example shows a stored result instead of regenerating.
    examples = gr.Examples(
        examples=examples_inputs,
        inputs=[input_text],
        fn=translate,
        outputs=output,
        cache_examples=True,
    )

print("Prepared")
demo.launch()