import os
import re
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Automatically detect GPU or use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Default model path
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"
# Load tokenizer and model with trusted remote code
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)
# List of available tasks (histone-mark prediction)
tasks = [
    'H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac',
    'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3',
]
mapping = {
    '1': 'It is an',
    '0': 'It is not an',
    'No valid prediction': 'Cannot be determined whether or not it is an',
}
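# Example: mapping['1'] + ' H3' -> 'It is an H3'; the keys are the strings
# returned by extract_label below.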
def preprocess_response(response, mask_token="[MASK]"):
    """Extracts the response after the [MASK] token."""
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
        # Strip any leading whitespace or A/T/G/C bases left before the label.
        response = re.sub(r'^[\sATGC]+', '', response)
    return response
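# Example (hypothetical string, for illustration only):
#   preprocess_response('ACGT[MASK] 1') -> '1'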
def generate(dna_sequence, task_type, sample_num=1):
    """
    Generates a response based on the DNA sequence and selected task.

    Args:
        dna_sequence (str): The input DNA sequence.
        task_type (str): The selected task type.
        sample_num (int): Maximum number of new tokens to generate.

    Returns:
        str: Predicted function label.
    """
    if task_type is None:
        task_type = 'H3'
    # Build the prompt: sequence, task name, then a [MASK] slot for the label.
    dna_sequence = dna_sequence + task_type + "[MASK]"
    tokenized_message = tokenizer(
        [dna_sequence], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to(device)
    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    # Fall back to the "No valid prediction" phrasing if the model emits an unexpected digit.
    return f"{mapping.get(pred, mapping['No valid prediction'])} {task_type}"
def extract_label(message, task_type):
    """Extracts the prediction digit that follows the [MASK] token in the model's response."""
    # task_type is accepted for call-site symmetry but not needed here.
    if "[MASK]" not in message:
        return "No valid prediction"
    answer = message.split("[MASK]", 1)[1]
    match = re.search(r'\d+', answer)
    return match.group() if match else "No valid prediction"
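# Examples (hypothetical decoded replies, for illustration only):
#   extract_label('ACGTH3[MASK]1', 'H3') -> '1'
#   extract_label('ACGTH3', 'H3')        -> 'No valid prediction'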
# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        gr.Dropdown(choices=tasks, label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)
if __name__ == "__main__":
    interface.launch()