import random

from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Maximum number of *characters* of input passed to the model. This is a
# plain character slice, not a token count; it keeps inputs short enough to
# stay within the model's input limit.
MAX_INPUT_CHARS = 384

# Upper bound on how many empty records get_new_snippet() will skip per call.
MAX_SNIPPET_ATTEMPTS = 1000

# Load the BL dataset as a shuffled streaming dataset (nothing is downloaded
# up front; records are fetched lazily as they are consumed).
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
).shuffle(seed=42)

# One persistent iterator shared by every "Get New Snippet" click. Looping
# over `dataset_iter` directly inside the handler would begin a new pass on
# each call and, with the fixed shuffle seed, keep yielding the same first
# record.
_snippet_iter = iter(dataset_iter)

# Load the model
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", trust_remote_code=True)


def ner(text: str, labels: str, threshold: float, nested_ner: bool) -> tuple:
    """Run GLiNER on *text* and return data for both output components.

    Args:
        text: Input text; only the first ``MAX_INPUT_CHARS`` characters are used.
        labels: Comma-separated entity labels to look for.
        threshold: Score threshold forwarded to ``model.predict_entities``.
        nested_ner: When true, the model is called with ``flat_ner=False``.

    Returns:
        A 2-tuple ``(highlighted, entities)``: ``highlighted`` is a dict with
        ``"text"`` and ``"entities"`` keys for ``gr.HighlightedText``, and
        ``entities`` is the list returned by the model, shown as JSON.
    """
    # Drop blank labels such as the one produced by a trailing comma.
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]

    truncated_text = (text or "")[:MAX_INPUT_CHARS]

    # Nothing to predict: skip the model call and return empty results.
    if not truncated_text.strip() or not labels_list:
        return {"text": truncated_text, "entities": []}, []

    entities = model.predict_entities(
        truncated_text,
        labels_list,
        flat_ner=not nested_ner,
        threshold=threshold,
    )

    # gr.HighlightedText's dict form takes one dict per span with
    # "entity", "start" and "end" keys (not positional tuples).
    highlights = [
        {"entity": ent["label"], "start": ent["start"], "end": ent["end"]}
        for ent in entities
    ]

    return {"text": truncated_text, "entities": highlights}, entities


def get_new_snippet() -> str:
    """Return the text of the next non-empty record from the dataset stream.

    Skips records whose ``"text"`` is missing or blank, giving up after
    ``MAX_SNIPPET_ATTEMPTS`` records or when the stream is exhausted, in which
    case a fallback message is returned instead.
    """
    for _ in range(MAX_SNIPPET_ATTEMPTS):
        sample = next(_snippet_iter, None)
        if sample is None:  # stream exhausted
            break
        snippet = sample.get("text")
        if snippet and snippet.strip():
            return snippet
    return "No more snippets available."


with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        """
        # General Entity Recognition Demo
        This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities using a fine-tuned GLiNER model. You can specify the entities you want to find, and results will be displayed in a color-coded format.
        """
    )

    # Text to analyse; filled in by the "Get New Snippet" button or by hand.
    input_text = gr.Textbox(
        value="Click 'Get New Snippet' to load a random sample from the British Library Dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )

    with gr.Row() as row:
        labels = gr.Textbox(
            value="People, Places",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )

    # Output components: color-coded spans plus the raw entity list.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)