import re
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Automatically detect a GPU, falling back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model ID for the fine-tuned multitask model
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"

# Load tokenizer and model; the repo ships custom modeling code, hence trust_remote_code=True
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)

# Histone-modification prediction tasks supported by the model
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']
# Map the model's raw prediction to a human-readable phrase
mapping = {
    '1': 'It is a',
    '0': 'It is not a',
    'No valid prediction': 'Cannot be determined whether or not it is a',
}
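# Example: a prediction of '1' for task 'H3' renders as "It is a H3" (see generate() below).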
def preprocess_response(response, mask_token="[MASK]"):
    """Extracts the text generated after the [MASK] token."""
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    # Strip leading whitespace and any echoed A/T/G/C characters
    response = re.sub(r'^[\sATGC]+', '', response)
    return response
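# Illustrative behavior (hypothetical decoded output):
#   preprocess_response("ATGC[MASK] 1")  ->  "1"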
def generate(dna_sequence, task_type, sample_num=1):
    """
    Generates a prediction for the given DNA sequence and selected task.

    Args:
        dna_sequence (str): The input DNA sequence.
        task_type (str): The selected task type.
        sample_num (int): Maximum number of new tokens to generate.

    Returns:
        str: Predicted function label.
    """
    if task_type is None:
        task_type = 'H3'
    # Build the prompt: sequence + task tag + [MASK] placeholder for the answer
    dna_sequence = dna_sequence + task_type + "[MASK]"
    tokenized_message = tokenizer(
        [dna_sequence], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to(device)
    # Greedy decoding: the model fills in the token(s) after [MASK]
    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    # Fall back to the "undetermined" phrase if the prediction is not '0' or '1'
    return f"{mapping.get(pred, mapping['No valid prediction'])} {task_type}"
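# Illustrative call (hypothetical sequence; the output shown is one possible result —
# real input comes from the Gradio UI):
#   generate("ATGCGTACGTTAGC", "H3K4me3")  ->  "It is not a H3K4me3"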
def extract_label(message, task_type):
    """Extracts the prediction digit following the [MASK] token; task_type is currently unused."""
    parts = message.split('[MASK]')
    if len(parts) < 2:
        return "No valid prediction"
    match = re.search(r'\d+', parts[1])
    return match.group() if match else "No valid prediction"
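# Illustrative behavior (hypothetical decoded replies):
#   extract_label("ATGCH3[MASK]1", "H3")  ->  "1"
#   extract_label("ATGCH3", "H3")         ->  "No valid prediction"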
# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        gr.Dropdown(choices=tasks, label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)

if __name__ == "__main__":
    interface.launch()