# RWKV-5 World v2 Translator — Gradio app (Hugging Face Space)
import gradio as gr
import os
import gc
import copy
from huggingface_hub import hf_hub_download
# Explicit imports instead of the original `from pynvml import *` wildcard,
# so the NVML names used below are visible and nothing shadows builtins.
from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, NVMLError

# Flag to check if a CUDA-capable GPU is present
HAS_GPU = False

# Model title and generation context-size limit
ctx_limit = 2000
title = "RWKV-5-World-1B5-v2-Translator"
model_file = "RWKV-5-World-1B5-v2-20231025-ctx4096"

# Probe NVML for GPUs; any NVML error (no driver, no GPU) leaves us in CPU mode.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    print(error)

os.environ["RWKV_JIT_ON"] = '1'

# Model strategy to use
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'  # if '1' then use CUDA kernel for seq mode (much faster)

# Switch to GPU mode
if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"

# Load the model — the RWKV_* env vars above must be set BEFORE importing rwkv,
# which is why these imports are deliberately not at the top of the file.
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)

from rwkv.utils import PIPELINE
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Precomputation of the state
def precompute_state(text):
    """Feed *text* through the model once and return the resulting RWKV state.

    The returned state is a deep copy, so it can be safely reused as the
    starting point for many later generations without being mutated.

    Args:
        text: Prompt text to encode and run through ``model.forward``.

    Returns:
        A deep copy of the model state after processing *text*.
    """
    state = None
    text_encoded = pipeline.encode(text)
    _, state = model.forward(text_encoded, state)
    # BUG FIX: the original body ended with `yield dict(state)`, which made
    # this a generator function — callers (PREFIX_STATE below) received a
    # generator object instead of a state.  `dict(state)` was also wrong for
    # the list-based RWKV state; `copy` is imported at the top of the file and
    # was otherwise unused, so deepcopy is the evident intent.
    return copy.deepcopy(state)
# Precomputing the base instruction set
# NOTE: plain string literal — the original used `f'''...'''` with no
# placeholders; an f-string here would raise at import time if the prompt
# text ever gained a literal brace, so the `f` prefix is dropped.
INSTRUCT_PREFIX = '''
The following is a set of instruction rules, that can translate spoken text to zombie speak. And vice visa.
# Zombie Speak Rules:
- Replace syllables with "uh" or "argh"
- Add "uh" and "argh" sounds between words
- Repeat words and letters, especially vowels
- Use broken grammar and omit small words like "the", "a", "is"
# To go from zombie speak back to English:
- Remove extra "uh" and "argh" sounds
- Replace repeated letters with one instance
- Add omitted small words like "the", "a", "is" back in
- Fix grammar and sentence structure
# Here are several examples:
## English:
"Hello my friend, how are you today?"
## Zombie:
"Hell-uh-argh myuh fruh-end, hargh-owuh argh yuh-uh toduh-ay?"
## Zombie:
"Brargh-ains argh-uh foo-duh"
## English:
"Brains are food"
## English:
"Good morning! How are you today? I hope you are having a nice day. The weather is supposed to be sunny and warm this afternoon. Maybe we could go for a nice walk together and stop to get ice cream. That would be very enjoyable. Well, I will talk to you soon!"
## Zombie:
"Guh-ood morngh-ing! Hargh-owuh argh yuh-uh toduh-ay? Iuh hargh-ope yuh-uh argh havi-uh-nguh nuh-ice duh-ay. Thuh-uh weath-uh-eruh izzuh suh-pose-duh tuh-uh beh sunn-eh an-duh war-muh thizuh aft-erng-oon. May-buh-uh weh coulduh gargh-oh fargh-oruh nuh-ice wal-guh-kuh toge-the-ruh an-duh stargh-op tuh-uh geh-etuh izz-creem. Tha-at wou-duh beh ve-reh uhn-joy-ab-buhl. Well, I wih-ll targh-alk tuh-uh yuh-oo soo-oon!"
'''
# Get the prefix state: run the instruction prompt through the model once at
# startup so translate() can reuse it instead of re-encoding the prefix.
# NOTE(review): precompute_state as written uses `yield`, so this actually
# binds a generator object rather than a state — confirm against the
# definition above; it looks like `return` was intended there.
PREFIX_STATE = precompute_state(INSTRUCT_PREFIX)
# Translation logic
def translate(text, target_language, inState=PREFIX_STATE):
    """Stream a translation of *text* into *target_language*.

    Generator: yields the accumulated output string after each cleanly
    decodable token, so the Gradio Textbox updates incrementally.

    Args:
        text: Source text to translate.
        target_language: Target language name (one of LANGUAGES).
        inState: Precomputed RWKV state carrying the instruction prefix.
            Deep-copied so the shared precomputed state is never mutated.
    """
    prompt = f"Translate the following text to {target_language}\n # Input Text:\n{text}\n\n# Output Text:\n"
    ctx = prompt.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    state = None
    if inState is not None:
        # Deep-copy so every request starts from the pristine prefix state.
        # (The original `dict(inState)` fails on the list-based RWKV state.)
        state = copy.deepcopy(inState)

    # Clear GC before a potentially long generation
    gc.collect()
    if HAS_GPU:
        # BUG FIX: torch was referenced here but never imported in this file,
        # so this branch raised NameError on any GPU machine.
        import torch
        torch.cuda.empty_cache()

    # Generate token by token; the first step feeds the whole (truncated)
    # prompt, subsequent steps feed the single sampled token back in.
    for i in range(ctx_limit):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        token = pipeline.sample_logits(out)
        if token == 0:  # EOS token
            break
        all_tokens += [token]

        # Only flush once the pending tokens decode to valid UTF-8
        # (no replacement character), to avoid emitting split multibyte chars.
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1

        # If the model starts a new markdown section (i.e. the next example),
        # trim everything from that header onward and stop generating.
        if "# " in out_str and "\n#" in out_str:
            out_str = out_str.split("\n## ")[0].split("\n# ")[0]
            yield out_str.strip()
            break

    # Drop the large tensors promptly to keep peak memory down
    del out
    del state

    yield out_str.strip()
# Languages offered as translation targets in the UI dropdown (order here is
# the dropdown order).  Includes the novelty "Zombie Speak" target alongside
# real languages; names must match what the prompt in translate() expects.
LANGUAGES = [
    "English",
    "Zombie Speak",
    "Chinese",
    "Spanish",
    "Bengali",
    "Hindi",
    "Portuguese",
    "Russian",
    "Japanese",
    "German",
    "Chinese (Wu)",
    "Javanese",
    "Korean",
    "French",
    "Vietnamese",
    "Telugu",
    "Chinese (Yue)",
    "Marathi",
    "Tamil",
    "Turkish",
    "Urdu",
    "Chinese (Min Nan)",
    "Chinese (Jin Yu)",
    "Gujarati",
    "Polish",
    "Arabic (Egyptian Spoken)",
    "Ukrainian",
    "Italian",
    "Chinese (Xiang)",
    "Malayalam",
    "Chinese (Hakka)",
    "Kannada",
    "Oriya",
    "Panjabi (Western)",
    "Panjabi (Eastern)",
    "Sunda",
    "Romanian",
    "Bhojpuri",
    "Azerbaijani (South)",
    "Farsi (Western)",
    "Maithili",
    "Hausa",
    "Arabic (Algerian Spoken)",
    "Burmese",
    "Serbo-Croatian",
    "Chinese (Gan)",
    "Awadhi",
    "Thai",
    "Dutch",
    "Yoruba",
    "Sindhi",
    "Arabic (Moroccan Spoken)",
    "Arabic (Saidi Spoken)",
    "Uzbek, Northern",
    "Malay",
    "Amharic",
    "Indonesian",
    "Igbo",
    "Tagalog",
    "Nepali",
    "Arabic (Sudanese Spoken)",
    "Saraiki",
    "Cebuano",
    "Arabic (North Levantine Spoken)",
    "Thai (Northeastern)",
    "Assamese",
    "Hungarian",
    "Chittagonian",
    "Arabic (Mesopotamian Spoken)",
    "Madura",
    "Sinhala",
    "Haryanvi",
    "Marwari",
    "Czech",
    "Greek",
    "Magahi",
    "Chhattisgarhi",
    "Deccan",
    "Chinese (Min Bei)",
    "Belarusan",
    "Zhuang (Northern)",
    "Arabic (Najdi Spoken)",
    "Pashto (Northern)",
    "Somali",
    "Malagasy",
    "Arabic (Tunisian Spoken)",
    "Rwanda",
    "Zulu",
    "Bulgarian",
    "Swedish",
    "Lombard",
    "Oromo (West-central)",
    "Pashto (Southern)",
    "Kazakh",
    "Ilocano",
    "Tatar",
    "Fulfulde (Nigerian)",
    "Arabic (Sanaani Spoken)",
    "Uyghur",
    "Haitian Creole French",
    "Azerbaijani, North",
    "Napoletano-calabrese",
    "Khmer (Central)",
    "Farsi (Eastern)",
    "Akan",
    "Hiligaynon",
    "Kurmanji",
    "Shona"
]
# Example data shown in the gr.Dataset widget: [source_text, target_language]
# pairs; the first entry also seeds the input Textbox/Dropdown defaults.
EXAMPLES = [
    ["Brargh-ains argh-uh foo-duh", "English"],
    ["I Want to eat your brains", "Zombie Speak"],
    ["Hello, how are you?", "French"],
    ["Hello, how are you?", "Spanish"],
    ["Hello, how are you?", "Chinese"],
    ["Bonjour, comment ça va?", "English"],
    ["Hola, ¿cómo estás?", "English"],
    ["你好吗?", "English"],
    ["Guten Tag, wie geht es Ihnen?", "English"],
    ["Привет, как ты?", "English"],
    ["مرحبًا ، كيف حالك؟", "English"],
]
# Gradio interface
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>RWKV-5 World v2 - {title}</h1></div>")
    gr.Markdown("This is the RWKV-5 World v2 1B5 model tailored for translation. With a halloween zombie speak twist")

    # Input and output components.
    # BUG FIX: gradio 3.x (which this file targets — see queue(concurrency_count=...))
    # takes `value=` for the initial component value; `default=` was the 2.x
    # kwarg and raises TypeError on 3.x.
    text = gr.Textbox(lines=5, label="Source Text", placeholder="Enter the text you want to translate...", value=EXAMPLES[0][0])
    target_language = gr.Dropdown(choices=LANGUAGES, label="Target Language", value=EXAMPLES[0][1])
    output = gr.Textbox(lines=5, label="Translated Text")
    submit = gr.Button("Translate", variant="primary")

    # Example data
    data = gr.Dataset(components=[text, target_language], samples=EXAMPLES, label="Example Translations", headers=["Text", "Target Language"])

    # Button action: translate is a generator, so the output box streams
    submit.click(translate, [text, target_language], [output])
    # Clicking a sample copies its two fields back into the inputs
    data.click(lambda x: x, [data], [text, target_language])

# Gradio launch: single worker (one model instance) with a bounded queue
demo.queue(concurrency_count=1, max_size=10)
demo.launch(share=False)