# app.py — Hugging Face Space by KavinduHansaka
# Latest commit 443c9b1 ("Update app.py", verified); file size 4.91 kB
import os
import tempfile

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification
# Authenticate with Hugging Face using the token from the HF_TOKEN environment
# variable. Only log in when a token is actually set: per the huggingface_hub
# docs, login(token=None) falls back to an interactive prompt, which a headless
# server process cannot answer.
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)

# Load Phi-4 Mini: the instruction-following model behind the grammar, tone and
# fluency tools (see generate_phi_prompt).
phi_model_id = "microsoft/phi-4-mini-instruct"
phi_tokenizer = AutoTokenizer.from_pretrained(phi_model_id, token=HUGGINGFACE_TOKEN)
phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_id, torch_dtype="auto", device_map="auto", token=HUGGINGFACE_TOKEN
)
phi_pipe = pipeline("text-generation", model=phi_model, tokenizer=phi_tokenizer)

# Load T5 for paraphrasing.
t5_pipe = pipeline("text2text-generation", model="google-t5/t5-base")

# Load the AI-text detector (a RoBERTa sequence classifier).
ai_model_id = "openai-community/roberta-base-openai-detector"
ai_tokenizer = AutoTokenizer.from_pretrained(ai_model_id)
ai_model = AutoModelForSequenceClassification.from_pretrained(ai_model_id)
# Text chunking
def chunk_text(text, max_tokens=300):
    """Group the paragraphs of *text* into size-bounded chunks.

    Paragraphs (pieces separated by a double newline) are accumulated in
    order while the running word count stays strictly below *max_tokens*.
    When the next paragraph would reach the limit, the current chunk is
    closed and a new one is started with that paragraph. Despite the
    parameter name, sizes are counted in whitespace-separated words, not
    model tokens. A single paragraph that is longer than the limit is kept
    whole rather than split.

    Args:
        text: Input text. ``None`` or an empty string is treated as no text.
        max_tokens: Word budget per chunk.

    Returns:
        A list of non-empty, stripped chunks. Paragraphs within a chunk are
        re-joined by double newlines. Blank input yields ``[]``.
    """
    if not text:
        return []

    chunks = []
    current = []       # paragraphs of the chunk being built
    current_words = 0  # total word count of ``current``

    for para in text.split("\n\n"):
        para_words = len(para.split())
        if current_words + para_words < max_tokens:
            current.append(para)
            current_words += para_words
            continue
        # Close the current chunk. It may hold no text at all, e.g. when the
        # very first paragraph already exceeds the budget. Skip it in that case
        # so callers never get an empty chunk and run a model on nothing.
        chunk = "\n\n".join(current).strip()
        if chunk:
            chunks.append(chunk)
        current, current_words = [para], para_words

    tail = "\n\n".join(current).strip()
    if tail:
        chunks.append(tail)
    return chunks
# Phi-based instruction
def generate_phi_prompt(text, instruction):
    """Apply *instruction* to *text* with the Phi pipeline, one chunk at a time.

    The text is split with ``chunk_text``. Each chunk gets the prompt
    ``"{instruction}\\n{chunk}\\nResponse:"`` and is decoded greedily for up
    to 256 new tokens. The per-chunk responses are re-joined with blank lines.

    Args:
        text: The user's input text.
        instruction: Natural-language instruction placed before each chunk.

    Returns:
        The model's responses for all chunks, separated by double newlines.
        Returns ``""`` when *text* has no non-blank chunks.
    """
    outputs = []
    for chunk in chunk_text(text):
        if not chunk:
            # Don't spend a generation call on an empty prompt.
            continue
        prompt = f"{instruction}\n{chunk}\nResponse:"
        # return_full_text=False returns only the newly generated text, so the
        # echoed prompt never has to be split off. Splitting the full text on
        # "Response:" picked the wrong segment whenever the user's own text
        # contained that word. Decoding is greedy (do_sample=False), so no
        # temperature is passed: it would be ignored, with a warning.
        generated = phi_pipe(
            prompt,
            max_new_tokens=256,
            do_sample=False,
            return_full_text=False,
        )[0]["generated_text"]
        # Cut at any further "Response:" marker the model emits, so a run-on
        # continuation of the prompt pattern is dropped.
        outputs.append(generated.partition("Response:")[0].strip())
    return "\n\n".join(outputs)
# Rewriting tools: each one wraps generate_phi_prompt with a fixed instruction.
def fix_grammar(text):
    """Return *text* with grammar and punctuation corrected by the Phi model."""
    instruction = "Correct all grammar and punctuation errors in the following text. Provide only the corrected version:"
    return generate_phi_prompt(text, instruction=instruction)
def improve_tone(text):
    """Return *text* rewritten by the Phi model in a formal, professional register."""
    instruction = "Rewrite the following text to sound more formal, polite, and professional:"
    return generate_phi_prompt(text, instruction=instruction)
def improve_fluency(text):
    """Return *text* rewritten by the Phi model for clarity and natural flow."""
    instruction = "Rewrite the following to improve its clarity, sentence flow, and natural fluency:"
    return generate_phi_prompt(text, instruction=instruction)
def paraphrase(text):
    """Paraphrase *text* with the T5 pipeline, in chunks of roughly 60 words.

    Each chunk is prefixed with "paraphrase this sentence: " and decoded with
    5-beam search and ``max_length=128``. Output for an unusually long chunk
    (one paragraph over the budget) is therefore capped at 128 tokens. Chunk
    outputs are re-joined with double newlines.

    NOTE(review): "paraphrase this sentence:" does not appear to be one of the
    task prefixes google-t5/t5-base was trained on, so output quality may be
    poor. Consider a checkpoint fine-tuned for paraphrasing.

    Args:
        text: The user's input text.

    Returns:
        The paraphrased chunks joined by double newlines. Returns ``""`` for
        blank input.
    """
    outputs = []
    for chunk in chunk_text(text, max_tokens=60):
        if not chunk:
            # Never send the bare prefix to the model.
            continue
        result = t5_pipe(
            "paraphrase this sentence: " + chunk,
            max_length=128,
            num_beams=5,
            do_sample=False,
        )
        outputs.append(result[0]["generated_text"])
    return "\n\n".join(outputs)
# Upload/download handlers
def load_file(file_obj):
    """Return the UTF-8 text of an uploaded file, or "" when nothing was uploaded.

    What ``gr.File`` hands to an event handler depends on the Gradio version
    and the component's ``type``. It may be a filepath string (Gradio 4+
    default), a temp-file wrapper exposing ``.name`` (Gradio 3 default), or
    raw ``bytes`` (``type="binary"``). A filepath string has no ``.read()``
    method, so this resolves a path when one is available and opens it
    explicitly. Otherwise it reads bytes or a file-like object directly.

    Line endings are normalised to LF, so files saved with CRLF line endings
    still split into paragraphs on double newlines in ``chunk_text``.

    Args:
        file_obj: The value produced by the upload component, or ``None``.

    Returns:
        The decoded file contents.

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if file_obj is None:
        return ""

    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = getattr(file_obj, "name", None)

    if isinstance(path, (str, os.PathLike)):
        # Text-mode open() already translates CRLF/CR line endings to LF.
        with open(path, encoding="utf-8") as f:
            return f.read()

    # Raw bytes (type="binary"), or a file-like object with no usable path.
    data = file_obj if isinstance(file_obj, (bytes, bytearray)) else file_obj.read()
    text = bytes(data).decode("utf-8") if isinstance(data, (bytes, bytearray)) else data
    return text.replace("\r\n", "\n").replace("\r", "\n")
def save_file(text):
    """Write *text* to a new UTF-8 ``.txt`` file and return its path.

    Every call creates its own uniquely named file in the system temp
    directory. A single fixed path would be shared by all calls, including
    concurrent user sessions, so one user could overwrite or download
    another's output. A hard-coded ``/tmp`` would also not exist on every
    platform.

    Args:
        text: Text to save. ``None`` (an empty component) is written as an
            empty file.

    Returns:
        Path of the newly written file, suitable as a ``gr.File`` output value.
    """
    fd, path = tempfile.mkstemp(prefix="output_", suffix=".txt")
    # os.fdopen takes ownership of the descriptor; the with-block closes it.
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(text or "")
    return path
# AI Detection function
def detect_ai_text(text):
    """Score *text* with the RoBERTa OpenAI-output detector.

    The tokenizer is called with ``truncation=True``, so input longer than the
    model's maximum length is cut. Only the beginning of a long document is
    scored.

    Args:
        text: Text to classify.

    Returns:
        A dict mapping "Likely Human" and "Likely AI-Generated" to softmax
        probabilities rounded to 2 decimals, in the format ``gr.Label``
        displays.
    """
    inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = ai_model(**inputs).logits
    # One input string -> take the single row of class probabilities.
    probs = torch.softmax(logits, dim=-1)[0]

    # Find each class's index by its name in the model config instead of
    # assuming an order. NOTE: this detector's config is expected to label
    # index 0 "Fake" and index 1 "Real"; OpenAI's reference detector unpacks
    # `fake, real = probs`. That order is also the fallback if the names are
    # not found.
    label_to_index = {
        str(name).lower(): int(idx) for idx, name in ai_model.config.id2label.items()
    }
    ai_idx = label_to_index.get("fake", 0)
    human_idx = label_to_index.get("real", 1)
    return {
        "Likely Human": round(probs[human_idx].item(), 2),
        "Likely AI-Generated": round(probs[ai_idx].item(), 2),
    }
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# ✍️ AI Writing Assistant + Detector")
    gr.Markdown("Fix grammar, improve tone and fluency, paraphrase text, detect AI content, and upload/download files.")

    # Input: upload a .txt file and load it into the textbox, or paste text directly.
    with gr.Row():
        file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
        load_btn = gr.Button("📥 Load Text")
    input_text = gr.Textbox(lines=12, label="Or Paste Text")
    load_btn.click(fn=load_file, inputs=file_input, outputs=input_text)

    # One button per tool.
    with gr.Row():
        btn_grammar = gr.Button("✔️ Fix Grammar")
        btn_tone = gr.Button("🎯 Improve Tone")
        btn_fluency = gr.Button("🔄 Improve Fluency")
        btn_paraphrase = gr.Button("🌀 Paraphrase")
        btn_detect = gr.Button("🕵️ Detect AI vs Human")
    output_text = gr.Textbox(lines=12, label="Output")
    ai_output = gr.Label(label="AI Detection Result")

    # The rewriting tools share the Output textbox; detection fills the Label.
    btn_grammar.click(fn=fix_grammar, inputs=input_text, outputs=output_text)
    btn_tone.click(fn=improve_tone, inputs=input_text, outputs=output_text)
    btn_fluency.click(fn=improve_fluency, inputs=input_text, outputs=output_text)
    btn_paraphrase.click(fn=paraphrase, inputs=input_text, outputs=output_text)
    btn_detect.click(fn=detect_ai_text, inputs=input_text, outputs=ai_output)

    gr.Markdown("## 📤 Download Output")
    download_btn = gr.Button("💾 Download as .txt")
    # Output-only component: interactive=False shows a download link instead
    # of an upload drop zone.
    download_file = gr.File(label="Click to download", interactive=False)
    download_btn.click(fn=save_file, inputs=output_text, outputs=download_file)

demo.launch()