File size: 1,947 Bytes
0804cf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Hugging Face Hub repository of the checkpoint served by this demo.
model_id = "textcleanlm/textclean-4B"

# Process-wide cache, populated lazily by load_model() on first use.
model, tokenizer = None, None

def load_model():
    """Load (once) and return the model and tokenizer for ``model_id``.

    On the first call the tokenizer is loaded and a sequence of
    ``transformers`` auto-classes is tried until one accepts the checkpoint.
    The results are stored in the module-level ``model`` / ``tokenizer``
    globals, so later calls return the cached objects without reloading.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        ValueError: if none of the candidate model classes could load the
            checkpoint. The last underlying loading error is chained as
            ``__cause__`` so the real reason is visible in the traceback.
    """
    global model, tokenizer
    if model is not None:
        return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Add padding token if needed: fall back to EOS when the checkpoint does
    # not define a dedicated pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The checkpoint's architecture is not assumed up front, so try the
    # auto-classes in order and keep the first that loads.
    # NOTE(review): a bare AutoModel has no LM head, so clean_text's
    # generate() will likely fail if that fallback is the one that succeeds.
    last_error = None
    for model_class in (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel):
        try:
            model = model_class.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        except Exception as err:  # any load failure: try the next class
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit
            # still propagate, and remember the error for the final report.
            last_error = err
            continue
        break

    if model is None:
        raise ValueError(f"Could not load model {model_id}") from last_error

    return model, tokenizer

@spaces.GPU(duration=60)
def clean_text(text):
    """Return a cleaned version of ``text`` generated by the model.

    Args:
        text: Raw input text from the Gradio textbox. Inputs longer than
            512 tokens are truncated before generation.

    Returns:
        str: The decoded generated text with special tokens removed, or
        ``""`` when ``text`` is empty or whitespace-only.
    """
    # An empty textbox gives nothing to clean; generating from an empty
    # prompt is wasted GPU time and can error for decoder-only models.
    if not text or not text.strip():
        return ""

    model, tokenizer = load_model()

    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    # With device_map="auto" the model is not guaranteed to live on the
    # default CUDA device, so send inputs to wherever the model reports.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens bounds only the generated tokens; max_length
            # would include the prompt for decoder-only models and leave no
            # room to generate once the input hits the 512-token limit.
            max_new_tokens=512,
            num_beams=4,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    generated = outputs[0]
    # Decoder-only (causal) models return prompt + continuation; drop the
    # prompt so the input is not echoed back as "cleaned" text.
    # Encoder-decoder outputs contain only decoder tokens.
    if not getattr(model.config, "is_encoder_decoder", False):
        generated = generated[prompt_len:]

    return tokenizer.decode(generated, skip_special_tokens=True)

# Gradio UI: one multi-line text box in, the model's cleaned text out.
_input_box = gr.Textbox(
    lines=5,
    placeholder="Enter text to clean...",
    label="Input Text",
)
_output_box = gr.Textbox(lines=5, label="Cleaned Text")

iface = gr.Interface(
    fn=clean_text,
    inputs=_input_box,
    outputs=_output_box,
    title="TextClean-4B Demo",
    description="Simple demo for text cleaning using textcleanlm/textclean-4B model",
)

# Only start the web server when this file is executed directly.
if __name__ == "__main__":
    iface.launch()