|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSequenceClassification |
|
import os |
|
from huggingface_hub import login |
|
import torch |
|
|
|
|
|
# Hugging Face access token, read from the HF_TOKEN environment variable.
# It is None when the variable is not set.
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")

# Authenticate only when a token is actually configured. Calling
# login(token=None) makes huggingface_hub fall back to an interactive prompt,
# which cannot be answered when the app starts in a headless environment.
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)
|
|
|
|
|
# Text-generation model used by the grammar / tone / fluency helpers.
phi_model_id = "microsoft/phi-4-mini-instruct"

# Tokenizer and weights are fetched with the same access token.
phi_tokenizer = AutoTokenizer.from_pretrained(
    phi_model_id,
    token=HUGGINGFACE_TOKEN,
)
phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_id,
    token=HUGGINGFACE_TOKEN,
    torch_dtype="auto",  # dtype selection is delegated to the library
    device_map="auto",  # device placement is delegated to the library
)

# Pipeline wrapping the already-loaded model and tokenizer.
phi_pipe = pipeline(
    "text-generation",
    tokenizer=phi_tokenizer,
    model=phi_model,
)
|
|
|
|
|
# Text-to-text pipeline used by paraphrase(); the model is referenced by its
# hub id and loaded by the pipeline factory itself.
t5_pipe = pipeline(
    task="text2text-generation",
    model="google-t5/t5-base",
)
|
|
|
|
|
# Sequence-classification model (and its tokenizer) used by detect_ai_text().
ai_model_id = "openai-community/roberta-base-openai-detector"
ai_model = AutoModelForSequenceClassification.from_pretrained(ai_model_id)
ai_tokenizer = AutoTokenizer.from_pretrained(ai_model_id)
|
|
|
|
|
def chunk_text(text, max_tokens=300):
    """Group the paragraphs of *text* into chunks of bounded size.

    Paragraphs are the pieces between blank lines (``"\\n\\n"``). Consecutive
    paragraphs are merged into one chunk while the chunk stays below
    ``max_tokens``; a paragraph that would reach the limit starts a new chunk.

    Args:
        text: The text to split. ``None`` or ``""`` yields no chunks.
        max_tokens: Size limit per chunk. Despite the name, size is measured
            in whitespace-separated words, not model tokens.

    Returns:
        A list of non-empty, stripped chunk strings. Paragraphs are never
        split, so a single paragraph larger than ``max_tokens`` becomes one
        oversized chunk of its own.
    """
    if not text:
        return []

    chunks = []
    pending = []  # paragraphs collected for the chunk being built
    pending_words = 0  # running word count of `pending` (avoids re-splitting)

    for para in text.split("\n\n"):
        para_words = len(para.split())
        if pending_words + para_words >= max_tokens:
            # Close the current chunk, but never emit an empty one: that used
            # to happen when the very first paragraph already hit the limit.
            chunk = "\n\n".join(pending).strip()
            if chunk:
                chunks.append(chunk)
            pending, pending_words = [], 0
        pending.append(para)
        pending_words += para_words

    tail = "\n\n".join(pending).strip()
    if tail:
        chunks.append(tail)
    return chunks
|
|
|
|
|
def generate_phi_prompt(text, instruction):
    """Apply *instruction* to *text* with the Phi pipeline, chunk by chunk.

    The text is split with ``chunk_text``; each chunk is sent to ``phi_pipe``
    as ``"<instruction>\\n<chunk>\\nResponse:"`` and the generated
    continuations are joined with blank lines.

    Args:
        text: The user text to rewrite.
        instruction: The instruction line placed in front of every chunk.

    Returns:
        The model responses for all chunks joined by ``"\\n\\n"``; an empty
        string when there is nothing to process.
    """
    outputs = []
    for chunk in chunk_text(text):
        if not chunk:
            # Never spend a generation call on an empty chunk.
            continue

        prompt = f"{instruction}\n{chunk}\nResponse:"

        # A rewrite is roughly as long as its input, so a fixed 256-token
        # budget would cut off the answer for long chunks. Scale the budget
        # with the chunk's token count, never going below the previous 256.
        chunk_tokens = len(phi_tokenizer(chunk)["input_ids"])
        max_new_tokens = max(256, int(chunk_tokens * 1.5))

        # Greedy decoding (do_sample=False); no temperature is passed because
        # it only applies to sampling-based generation.
        result = phi_pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )[0]["generated_text"]

        if result.startswith(prompt):
            # Normal case: the output echoes the prompt. Cut exactly the
            # prompt, so a "Response:" inside the user text or inside the
            # model answer cannot confuse the extraction.
            answer = result[len(prompt):]
        elif "Response:" in result:
            # Fallback when the echoed prompt is not byte-identical.
            answer = result.split("Response:", 1)[1]
        else:
            answer = result
        outputs.append(answer.strip())

    return "\n\n".join(outputs)
|
|
|
|
|
def fix_grammar(text):
    """Run *text* through ``generate_phi_prompt`` with a grammar-fix instruction."""
    instruction = (
        "Correct all grammar and punctuation errors in the following text. "
        "Provide only the corrected version:"
    )
    return generate_phi_prompt(text, instruction)
|
|
|
def improve_tone(text):
    """Run *text* through ``generate_phi_prompt`` with a formal-tone instruction."""
    instruction = (
        "Rewrite the following text to sound more formal, "
        "polite, and professional:"
    )
    return generate_phi_prompt(text, instruction)
|
|
|
def improve_fluency(text):
    """Run *text* through ``generate_phi_prompt`` with a fluency instruction."""
    instruction = (
        "Rewrite the following to improve its clarity, "
        "sentence flow, and natural fluency:"
    )
    return generate_phi_prompt(text, instruction)
|
|
|
def paraphrase(text):
    """Paraphrase *text* with the T5 pipeline, one small chunk at a time.

    The text is split with ``chunk_text(text, max_tokens=60)``; each chunk is
    prefixed with ``"paraphrase this sentence: "`` and passed to ``t5_pipe``
    using beam search (5 beams, no sampling, ``max_length=128``).

    Args:
        text: The user text to paraphrase.

    Returns:
        The generated outputs for all chunks joined by ``"\\n\\n"``; an empty
        string when there is nothing to process.
    """
    # NOTE(review): whether the t5-base checkpoint actually understands the
    # "paraphrase this sentence:" prefix is not visible here - confirm.
    outputs = []
    for chunk in chunk_text(text, max_tokens=60):
        if not chunk:
            # An empty chunk would make the model generate text from the bare
            # prefix and inject unrelated output into the result.
            continue
        result = t5_pipe(
            "paraphrase this sentence: " + chunk,
            max_length=128,
            num_beams=5,
            do_sample=False,
        )
        outputs.append(result[0]["generated_text"])
    return "\n\n".join(outputs)
|
|
|
|
|
def load_file(file_obj):
    """Return the text content of an uploaded file.

    Depending on the Gradio version and the ``type`` of the ``gr.File``
    component, the upload arrives as a filesystem path, as raw bytes, or as a
    file-like object. All three forms are accepted.

    Args:
        file_obj: ``None``, a path (``str`` / ``os.PathLike``), a
            ``bytes``-like payload, or an object with a ``read()`` method.

    Returns:
        The decoded text, or ``""`` when *file_obj* is ``None``.

    Raises:
        UnicodeDecodeError: If the content is not valid UTF-8.
        OSError: If a given path cannot be opened.
    """
    if file_obj is None:
        return ""

    if isinstance(file_obj, (str, os.PathLike)):
        # Path form: read the bytes ourselves.
        with open(file_obj, "rb") as handle:
            data = handle.read()
    elif isinstance(file_obj, (bytes, bytearray)):
        data = file_obj
    else:
        # File-like form (the only one handled previously).
        data = file_obj.read()

    if isinstance(data, str):
        # File-like object opened in text mode: already decoded.
        return data

    # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM,
    # which would otherwise end up as an invisible character in the text.
    return bytes(data).decode("utf-8-sig")
|
|
|
def save_file(text):
    """Write *text* to a new UTF-8 ``.txt`` file and return its path.

    Every call creates its own file in the system temporary directory. A
    single fixed path would be shared by all users of the app, so concurrent
    downloads could overwrite each other's content.

    Args:
        text: The text to save; ``None`` is written as an empty file.

    Returns:
        The path of the written file as a string. The file is not deleted
        automatically.
    """
    import tempfile  # local import: the module is only needed here

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        prefix="output_",
        suffix=".txt",
        delete=False,  # the file must outlive this function to be downloaded
    ) as handle:
        handle.write(text or "")
        path = handle.name
    return path
|
|
|
|
|
def detect_ai_text(text):
    """Score *text* with the detector model as human- vs. AI-written.

    The text is tokenized (truncated to the tokenizer's limit), classified by
    ``ai_model`` without gradient tracking, and the two class logits are
    turned into probabilities with a softmax.

    Args:
        text: The text to classify.

    Returns:
        A dict with the keys ``"Likely Human"`` and ``"Likely AI-Generated"``
        mapping to probabilities rounded to two decimals.
    """
    inputs = ai_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = ai_model(**inputs).logits
    # Single input -> take the only row of the (1, num_classes) result.
    probs = torch.softmax(logits, dim=-1)[0]

    # The detector checkpoint is published with class 0 = "Fake"
    # (machine-generated) and class 1 = "Real" (human-written), so index 0 must
    # feed the AI score - the previous mapping had the two swapped.
    # NOTE(review): label order taken from the model card / config - confirm.
    ai_index, human_index = 0, 1
    id2label = getattr(ai_model.config, "id2label", None) or {}
    if str(id2label.get(0, "")).lower() == "real":
        # Honour a checkpoint whose config declares the opposite order.
        ai_index, human_index = 1, 0

    return {
        "Likely Human": round(probs[human_index].item(), 2),
        "Likely AI-Generated": round(probs[ai_index].item(), 2),
    }
|
|
|
|
|
# Gradio UI: upload or paste text, run one of the text tools, download result.
# NOTE(review): the original indentation of this section was lost, so the
# nesting below (which components sit inside each gr.Row) is reconstructed -
# confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# ✍️ AI Writing Assistant + Detector")
    gr.Markdown("Fix grammar, improve tone and fluency, paraphrase text, detect AI content, and upload/download files.")

    # --- Input: file upload or pasted text ---
    with gr.Row():
        file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
        load_btn = gr.Button("📥 Load Text")
    input_text = gr.Textbox(lines=12, label="Or Paste Text")

    # Copy the uploaded file's content into the input textbox.
    load_btn.click(fn=load_file, inputs=file_input, outputs=input_text)

    # --- Actions ---
    with gr.Row():
        btn_grammar = gr.Button("✔️ Fix Grammar")
        btn_tone = gr.Button("🎯 Improve Tone")
        btn_fluency = gr.Button("🔄 Improve Fluency")
        btn_paraphrase = gr.Button("🌀 Paraphrase")
        btn_detect = gr.Button("🕵️ Detect AI vs Human")

    # --- Results ---
    output_text = gr.Textbox(lines=12, label="Output")
    ai_output = gr.Label(label="AI Detection Result")

    # The four rewriting tools share the output textbox; detection has its own
    # label component.
    btn_grammar.click(fn=fix_grammar, inputs=input_text, outputs=output_text)
    btn_tone.click(fn=improve_tone, inputs=input_text, outputs=output_text)
    btn_fluency.click(fn=improve_fluency, inputs=input_text, outputs=output_text)
    btn_paraphrase.click(fn=paraphrase, inputs=input_text, outputs=output_text)
    btn_detect.click(fn=detect_ai_text, inputs=input_text, outputs=ai_output)

    # --- Download ---
    gr.Markdown("## 📤 Download Output")
    download_btn = gr.Button("💾 Download as .txt")
    # Output-only component: interactive=False keeps it a download display
    # instead of presenting a second upload widget.
    download_file = gr.File(label="Click to download", interactive=False)
    download_btn.click(fn=save_file, inputs=output_text, outputs=download_file)


# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse `demo` or the helper functions) does not launch it.
if __name__ == "__main__":
    demo.launch()