import os
import re
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Automatically detect GPU or use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Default model path
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"
# Load tokenizer and model with trusted remote code
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)
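# Note: from_pretrained returns the model in eval mode, so it is ready for inference as-is.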
# Available prediction tasks (histone modification marks)
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']
mapping = {
    '1': 'It is a',
    '0': 'It is not a',
    'No valid prediction': 'Cannot be determined whether or not it is a',
}
def preprocess_response(response, mask_token="[MASK]"):
    """Extracts the text after the [MASK] token, stripping leading whitespace and nucleotides."""
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    response = re.sub(r'^[\sATGC]+', '', response)
    return response
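# Illustrative behavior (hypothetical reply string):
#   preprocess_response("ATGC[MASK] 1")  ->  "1"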
def generate(dna_sequence, task_type, sample_num=1):
    """
    Generates a response based on the DNA sequence and selected task.

    Args:
        dna_sequence (str): The input DNA sequence.
        task_type (str): The selected task type.
        sample_num (int): Number of new tokens to generate (passed to max_new_tokens).

    Returns:
        str: Predicted function label.
    """
    if task_type is None:
        task_type = 'H3'
    # Append the task name and a [MASK] token; the model fills in the label after the mask.
    dna_sequence = dna_sequence + task_type + "[MASK]"
    tokenized_message = tokenizer(
        [dna_sequence], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to(device)
    # Deterministic (greedy) decoding of the masked label token(s).
    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    return f"{mapping[pred]} {task_type}"
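# Illustrative call (hypothetical sequence): generate("ACGT" * 50, "H3K4me3")
# would return e.g. "It is a H3K4me3" or "It is not a H3K4me3".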
def extract_label(message, task_type):
    """Extracts the prediction label (the first digit after the [MASK] token)."""
    # task_type is accepted for call compatibility; the label always follows [MASK].
    parts = message.split('[MASK]', 1)
    if len(parts) < 2:
        return "No valid prediction"
    match = re.search(r'\d+', parts[1])
    return match.group() if match else "No valid prediction"
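# Illustrative behavior: extract_label("ATGC[MASK]1", "H3") -> "1";
# a reply with no digit after the mask yields "No valid prediction".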
# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        gr.Dropdown(choices=tasks, label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)
if __name__ == "__main__":
    interface.launch()
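    # interface.launch(share=True) would additionally create a temporary public URL.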