File size: 5,071 Bytes
43e3573
48d9b27
f944b14
43e3573
 
fcdf906
 
 
 
 
dcd8edc
fcdf906
 
 
 
ad23357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcdf906
 
f944b14
 
 
fcdf906
 
 
 
f944b14
 
 
fcdf906
cfda2d8
fcdf906
f944b14
fcdf906
dcd8edc
 
 
ad23357
c74a501
 
 
 
36fe8f3
c74a501
 
ad23357
 
 
286e134
dcd8edc
286e134
dcd8edc
 
fcdf906
f944b14
fcdf906
f944b14
 
 
 
fcdf906
e73ac0b
 
fcdf906
f944b14
 
 
 
347b36e
f944b14
 
 
fcdf906
ddcc7db
f944b14
 
 
 
 
 
 
 
 
 
fcdf906
94dc66c
fcdf906
f944b14
 
 
 
 
94dc66c
fcdf906
 
 
36fe8f3
d810520
 
 
 
 
 
 
 
f944b14
fcdf906
 
f944b14
 
d810520
f15518a
f944b14
fcdf906
 
edb2e2b
fcdf906
f81d7bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import re

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import frontmatter
import gradio as gr
import json
import spaces
import torch
from normalize import normalize

from transformers import AutoTokenizer
from modeling_nova import NovaTokenizer, NovaForCausalLM


def fix_assembly_tabs(asm_text):
    """
    Fix assembly code formatting by ensuring proper tab placement.

    Every line that looks like a disassembled instruction is rewritten to
    ``indent + address + ":" + TAB + hex_bytes + TAB + instruction + operands``.
    The hex bytes are collapsed to single-space-separated pairs. Trailing
    whitespace is stripped from every line. Lines that do not look like an
    instruction are kept unchanged apart from that strip. These include
    blank lines, labels such as ``<func0>:``, and lines holding only bytes
    with no mnemonic.

    Args:
        asm_text: Assembly listing, one instruction per line.

    Returns:
        The re-formatted listing, joined with newlines.
    """
    # Match the byte column as whole two-digit hex pairs. A lazy "hex or
    # whitespace" class would stop after the first byte and capture the
    # second byte as the mnemonic. Mnemonics made of hex letters ("add",
    # "fadd") are rejected as bytes because a byte must be followed by
    # whitespace. The negative lookahead stops a lone byte on a byte-only
    # line from being taken as the instruction.
    asm_line = re.compile(
        r"^(\s*)([0-9a-f]+):\s*"
        r"((?:[0-9a-f]{2}[ \t]+)*[0-9a-f]{2})\s+"
        r"(?![0-9a-f]{2}(?:\s|$))(\w+)"
        r"(\s+.*)?$",
        re.IGNORECASE,
    )

    fixed_lines = []
    for line in asm_text.split("\n"):
        line = line.rstrip()  # Remove trailing whitespace
        match = asm_line.match(line)
        if match is None:
            # Blank line, label, or anything else: keep as is.
            fixed_lines.append(line)
            continue

        indent, address, hex_bytes, instruction, operands = match.groups()
        # Normalize the padding between bytes to single spaces.
        hex_bytes = " ".join(hex_bytes.split())
        fixed_lines.append(
            f"{indent}{address}:\t{hex_bytes}\t{instruction}{operands or ''}"
        )

    return "\n".join(fixed_lines)


print("Downloading model")

# Hugging Face Hub repository holding both the tokenizer and the model weights.
MODEL_NAME = "lt-asset/nova-6.7b-bcr"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Padding reuses the EOS token; predict() passes pad_token_id to generate().
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Nova-specific wrapper; predict() calls its encode() with a per-character
# assembly/non-assembly mask.
nova_tokenizer = NovaTokenizer(tokenizer)

model = NovaForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

# Sample records shown as Gradio examples; each record provides the
# "type", "normalized_asm" and "c_func" keys read by the UI below.
# A context manager closes the file handle instead of leaking it.
with open("humaneval_decompile_nova_6.7b.json", "r", encoding="utf-8") as examples_file:
    examples = json.load(examples_file)


@spaces.GPU
def predict(type, input_asm, _c_source):
    """
    Run the Nova model on one assembly function and return its raw output.

    Args:
        type: Optimization label (the UI default is "O0"), inserted verbatim
            into the prompt header.
        input_asm: Assembly of a single function. If it does not contain a
            ``<func0>:`` label, it is treated as raw disassembly. In that case
            a synthetic ``<func0>:`` header is prepended when the first line
            has no ``<...>`` label. Tabs are then repaired with
            ``fix_assembly_tabs`` and the text is passed through
            ``normalize``.
        _c_source: Contents of the "Original C Code" box; not used here.

    Returns:
        The decoded text generated after the prompt, with special tokens
        removed. The prompt asks for the source code, so this is expected to
        be the decompiled function.

    Raises:
        gr.Error: If the (normalized) assembly does not start with the
            ``<func0>:`` label that the prompt format requires.
    """
    label = "<func0>:"

    if label not in input_asm:
        # Needs normalizing.

        # Add a bogus function header if the first line has no <name> label.
        first_line = input_asm.split("\n")[0]
        if "<" not in first_line or ">" not in first_line:
            print("Adding synthetic function header")
            input_asm = label + "\n" + input_asm

        # Fix tab formatting in assembly code before normalizing.
        input_asm = fix_assembly_tabs(input_asm)

        normalized_asm = normalize(input_asm)
        print(f"Normalized asm: {normalized_asm}")
    else:
        normalized_asm = input_asm

    asm = normalized_asm.strip()
    # Validate explicitly. ``assert`` is stripped under ``python -O`` and
    # would surface as an opaque AssertionError in the UI.
    if not asm.startswith(label):
        raise gr.Error(
            "The assembly must start with a '<func0>:' label "
            "(after normalization). Check the pasted input."
        )
    asm = asm[len(label):]

    prompt_before = f"# This is the assembly code with {type} optimization:\n<func0>:"
    prompt_after = "\nWhat is the source code?\n"

    inputs = prompt_before + asm + prompt_after
    print("Inputs:", inputs)

    # 0 for non-assembly code characters and 1 for assembly characters,
    # required by nova tokenizer.
    char_types = "0" * len(prompt_before) + "1" * len(asm) + "0" * len(prompt_after)

    tokenizer_output = nova_tokenizer.encode(inputs, "", char_types)
    input_ids = torch.LongTensor(tokenizer_output["input_ids"].tolist()).unsqueeze(0)
    print("Input IDs:", input_ids.shape)
    nova_attention_mask = torch.LongTensor(
        tokenizer_output["nova_attention_mask"]
    ).unsqueeze(0)

    generated_ids = model.generate(
        inputs=input_ids.cuda(),
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.95,
        num_return_sequences=1,
        do_sample=True,
        nova_attention_mask=nova_attention_mask.cuda(),
        no_mask_idx=torch.LongTensor([tokenizer_output["no_mask_idx"]]).cuda(),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Output 1:", generated_ids)

    # Decode only the newly generated tokens; the prompt tokens come first.
    output = tokenizer.decode(
        generated_ids[0][input_ids.size(1) :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print("Output 2:", output)

    return output

# Default value of the "Assembly Code" textbox. Each line holds an address,
# a tab, the hex bytes, a tab, and the instruction. It has no "<func0>:"
# line, so predict() prepends one, repairs the tabs with fix_assembly_tabs(),
# and runs normalize() on it. The tabs and trailing spaces inside the literal
# are part of the sample input and must be preserved.
example = """   0:	f3 0f 1e fa          	endbr64 
   4:	55                   	push   %rbp
   5:	48 89 e5             	mov    %rsp,%rbp
   8:	89 7d fc             	mov    %edi,-0x4(%rbp)
   b:	8b 45 fc             	mov    -0x4(%rbp),%eax
   e:	83 c0 2a             	add    $0x2a,%eax
  11:	5d                   	pop    %rbp
  12:	c3                   	ret    
"""

def _build_demo():
    """
    Assemble the Gradio interface around ``predict``.

    The three text inputs map positionally onto predict's
    (type, input_asm, _c_source) parameters. The single text output shows the
    raw model text. The description comes from the README body with its
    front matter removed. Each example row is taken from the loaded sample
    records.
    """
    input_boxes = [
        gr.Text(label="Optimization Type", value="O0"),
        gr.Text(label="Assembly Code (Normalized or not)", value=example),
        gr.Text(label="Original C Code"),
    ]
    output_box = gr.Text(label="Raw Nova Output")
    readme_body = frontmatter.load("README.md").content
    sample_rows = [
        [record["type"], record["normalized_asm"], record["c_func"]]
        for record in examples
    ]
    return gr.Interface(
        fn=predict,
        inputs=input_boxes,
        outputs=output_box,
        description=readme_body,
        examples=sample_rows,
    )


demo = _build_demo()
demo.launch(show_error=True)