nova-6.7b / app.py
ejschwartz's picture
Print normalized asm
286e134
import os
import re
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import frontmatter
import gradio as gr
import json
import spaces
import torch
from normalize import normalize
from transformers import AutoTokenizer
from modeling_nova import NovaTokenizer, NovaForCausalLM
# Matches one instruction line of objdump-style disassembly:
#   <indent><hex address>:<ws><hex byte pairs><ws><mnemonic>[<ws><operands>]
# The byte group only accepts whole two-digit hex pairs. The negative
# lookahead forbids the mnemonic from being another standalone byte pair.
# Without these guards a lazy byte group stops after the FIRST byte and treats
# the second byte as the mnemonic: "48 89 e5 mov ..." was split into
# bytes="48", mnemonic="89", which corrupted every multi-byte instruction.
_ASM_LINE_RE = re.compile(
    r"^(\s*)([0-9a-f]+):\s*"
    r"((?:[0-9a-f]{2}\s*)*?[0-9a-f]{2})\s+"
    r"(?![0-9a-f]{2}(?:\s|$))(\w+)(\s+.*)?$",
    re.IGNORECASE,
)


def fix_assembly_tabs(asm_text):
    """
    Fix assembly code formatting by ensuring proper tab placement.

    Every line that looks like an instruction is rebuilt as
    ``indent + address + ":" + TAB + hex_bytes + TAB + instruction + operands``,
    with the hex bytes separated by single spaces.

    The following lines are returned unchanged apart from trailing-whitespace
    removal:
      - function labels such as ``<func0>:``;
      - blank lines;
      - lines that carry only hex bytes, such as continuation lines of long
        instructions.

    Args:
        asm_text: Disassembly text, one instruction per line.

    Returns:
        The reformatted text, joined with "\\n".
    """
    fixed_lines = []
    for line in asm_text.split("\n"):
        line = line.rstrip()  # Remove trailing whitespace
        match = _ASM_LINE_RE.match(line) if line else None
        if match is None:
            # Blank line or not an instruction line: keep as is.
            fixed_lines.append(line)
            continue
        indent, address, hex_bytes, instruction, operands = match.groups()
        # Collapse any run of whitespace between bytes to a single space.
        hex_bytes = re.sub(r"\s+", " ", hex_bytes.strip())
        fixed_lines.append(
            f"{indent}{address}:\t{hex_bytes}\t{instruction}{operands or ''}"
        )
    return "\n".join(fixed_lines)
# Hugging Face Hub repository that both the tokenizer and the model come from.
MODEL_ID = "lt-asset/nova-6.7b-bcr"

print("Downloading model")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Nova-specific wrapper around the base tokenizer (from modeling_nova).
nova_tokenizer = NovaTokenizer(tokenizer)
model = NovaForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

# Example records for the demo's examples table. Each record is read with the
# keys "type", "normalized_asm" and "c_func".
# A context manager closes the file handle; the old json.load(open(...)) leaked it.
with open("humaneval_decompile_nova_6.7b.json", "r", encoding="utf-8") as examples_file:
    examples = json.load(examples_file)
@spaces.GPU
def predict(type, input_asm, _c_source):
    """Decompile one function's assembly with Nova and return the raw output.

    Args:
        type: Optimization-level label (e.g. "O0"), inserted verbatim into the
            prompt. The name shadows the builtin inside this function and is
            kept for interface compatibility.
        input_asm: Either already-normalized assembly that begins with
            ``<func0>:``, or raw disassembly. Raw disassembly gets a synthetic
            header if needed, has its tabs fixed, and is then passed through
            ``normalize``.
        _c_source: Reference C source from the examples table. It is unused.

    Returns:
        The decoded text generated after the prompt. Special tokens are
        skipped and the prompt tokens are removed.

    Raises:
        gr.Error: If the input is empty, or if the (normalized) assembly does
            not start with the ``<func0>:`` header.
    """
    func_header = "<func0>:"

    if not input_asm or not input_asm.strip():
        raise gr.Error("Please enter some assembly code.")

    # Only input that *starts* with the header counts as already normalized.
    # A plain substring test used to skip normalization for raw objdump output
    # such as "0000... <func0>:", which then always failed the header check.
    if input_asm.lstrip().startswith(func_header):
        normalized_asm = input_asm
    else:
        # Needs normalizing.
        # Add a bogus function header if the first line has no <...> label.
        first_line = input_asm.split("\n")[0]
        if "<" not in first_line or ">" not in first_line:
            print("Adding synthetic function header")
            input_asm = "<func0>:\n" + input_asm
        # Fix tab formatting in assembly code before normalizing.
        input_asm = fix_assembly_tabs(input_asm)
        normalized_asm = normalize(input_asm)
        print(f"Normalized asm: {normalized_asm}")

    asm = normalized_asm.strip()
    # Explicit check instead of `assert`: asserts are stripped under -O, and
    # gr.Error shows a readable message in the UI.
    if not asm.startswith(func_header):
        raise gr.Error(
            "Expected the (normalized) assembly to start with the "
            "'<func0>:' function header."
        )
    asm = asm[len(func_header):]

    prompt_before = f"# This is the assembly code with {type} optimization:\n<func0>:"
    prompt_after = "\nWhat is the source code?\n"
    inputs = prompt_before + asm + prompt_after
    print("Inputs:", inputs)

    # The Nova tokenizer requires one type character per input character:
    # "0" for non-assembly (prompt) text and "1" for assembly characters.
    char_types = "0" * len(prompt_before) + "1" * len(asm) + "0" * len(prompt_after)
    tokenizer_output = nova_tokenizer.encode(inputs, "", char_types)
    input_ids = torch.LongTensor(tokenizer_output["input_ids"].tolist()).unsqueeze(0)
    print("Input IDs:", input_ids.shape)
    nova_attention_mask = torch.LongTensor(
        tokenizer_output["nova_attention_mask"]
    ).unsqueeze(0)

    output = model.generate(
        inputs=input_ids.cuda(),
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.95,
        num_return_sequences=1,
        do_sample=True,
        nova_attention_mask=nova_attention_mask.cuda(),
        no_mask_idx=torch.LongTensor([tokenizer_output["no_mask_idx"]]).cuda(),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Output 1:", output)
    # Decode only the tokens after the prompt (the first input_ids.size(1)).
    output = tokenizer.decode(
        output[0][input_ids.size(1):],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print("Output 2:", output)
    return output
# Default contents of the assembly textbox. The disassembly is raw
# (un-normalized) and its first line has no "<...>" label. predict() therefore
# prepends a synthetic "<func0>:" header and normalizes it before prompting the
# model.
example = """ 0: f3 0f 1e fa endbr64
4: 55 push %rbp
5: 48 89 e5 mov %rsp,%rbp
8: 89 7d fc mov %edi,-0x4(%rbp)
b: 8b 45 fc mov -0x4(%rbp),%eax
e: 83 c0 2a add $0x2a,%eax
11: 5d pop %rbp
12: c3 ret
"""
# Three text inputs, matching predict()'s positional parameters:
# optimization level, assembly, and the (unused) reference C source.
input_components = [
    gr.Text(label="Optimization Type", value="O0"),
    gr.Text(label="Assembly Code (Normalized or not)", value=example),
    gr.Text(label="Original C Code"),
]
output_component = gr.Text(label="Raw Nova Output")

# Page description: the content of README.md as returned by
# frontmatter.load(...).content.
readme_body = frontmatter.load("README.md").content

# One examples-table row per bundled record, in input-component order.
example_rows = [
    [record["type"], record["normalized_asm"], record["c_func"]]
    for record in examples
]

demo = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs=output_component,
    description=readme_body,
    examples=example_rows,
)
demo.launch(show_error=True)