Spaces:
Running
on
Zero
Running
on
Zero
Update raw.py
Browse files
raw.py
CHANGED
@@ -105,34 +105,27 @@ def caption(input_image: Image.Image, prompt: str, temperature: float, top_p: fl
|
|
105 |
# WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
|
106 |
# but if using other combinations, always inspect the final input_ids to ensure they are correct. Oftentimes you will end up with multiple <bos> tokens
|
107 |
# if not careful, which can make the model perform poorly.
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
130 |
-
t.start()
|
131 |
-
|
132 |
-
outputs = []
|
133 |
-
for text in streamer:
|
134 |
-
outputs.append(text)
|
135 |
-
yield "".join(outputs)
|
136 |
|
137 |
@spaces.GPU()
|
138 |
@torch.no_grad()
|
|
|
105 |
# WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
|
106 |
# but if using other combinations, always inspect the final input_ids to ensure they are correct. Oftentimes you will end up with multiple <bos> tokens
|
107 |
# if not careful, which can make the model perform poorly.
|
108 |
+
convo_string = cap_processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
|
109 |
+
assert isinstance(convo_string, str)
|
110 |
+
inputs = cap_processor(text=[convo_string], images=[input_image], return_tensors="pt").to('cuda')
|
111 |
+
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
|
112 |
+
streamer = TextIteratorStreamer(cap_processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
113 |
+
generate_kwargs = dict(
|
114 |
+
**inputs,
|
115 |
+
max_new_tokens=max_new_tokens,
|
116 |
+
do_sample=True if temperature > 0 else False,
|
117 |
+
suppress_tokens=None,
|
118 |
+
use_cache=True,
|
119 |
+
temperature=temperature if temperature > 0 else None,
|
120 |
+
top_k=None,
|
121 |
+
top_p=top_p if temperature > 0 else None,
|
122 |
+
streamer=streamer,
|
123 |
+
)
|
124 |
+
_ = cap_model.generate(**generate_kwargs)
|
125 |
+
outputs = []
|
126 |
+
for text in streamer:
|
127 |
+
outputs.append(text)
|
128 |
+
yield "".join(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
@spaces.GPU()
|
131 |
@torch.no_grad()
|