LPX55 commited on
Commit
20d1485
·
verified ·
1 Parent(s): a530f5c

Update raw.py

Browse files
Files changed (1) hide show
  1. raw.py +21 -28
raw.py CHANGED
@@ -105,34 +105,27 @@ def caption(input_image: Image.Image, prompt: str, temperature: float, top_p: fl
105
  # WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
106
  # but if using other combinations, always inspect the final input_ids to ensure they are correct. Often you will end up with multiple <bos> tokens
107
  # if not careful, which can make the model perform poorly.
108
- convo_string = processor.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
109
- assert isinstance(convo_string, str)
110
-
111
- # Process the inputs
112
- inputs = processor(text=[convo_string], images=[input_image], return_tensors="pt").to('cuda')
113
- inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
114
-
115
- streamer = TextIteratorStreamer(processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
116
-
117
- generate_kwargs = dict(
118
- **inputs,
119
- max_new_tokens=max_new_tokens,
120
- do_sample=True if temperature > 0 else False,
121
- suppress_tokens=None,
122
- use_cache=True,
123
- temperature=temperature if temperature > 0 else None,
124
- top_k=None,
125
- top_p=top_p if temperature > 0 else None,
126
- streamer=streamer,
127
- )
128
-
129
- t = Thread(target=model.generate, kwargs=generate_kwargs)
130
- t.start()
131
-
132
- outputs = []
133
- for text in streamer:
134
- outputs.append(text)
135
- yield "".join(outputs)
136
 
137
  @spaces.GPU()
138
  @torch.no_grad()
 
105
  # WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
106
  # but if using other combinations, always inspect the final input_ids to ensure they are correct. Often you will end up with multiple <bos> tokens
107
  # if not careful, which can make the model perform poorly.
108
+ convo_string = cap_processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
109
+ assert isinstance(convo_string, str)
110
+ inputs = cap_processor(text=[convo_string], images=[input_image], return_tensors="pt").to('cuda')
111
+ inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
112
+ streamer = TextIteratorStreamer(cap_processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
113
+ generate_kwargs = dict(
114
+ **inputs,
115
+ max_new_tokens=max_new_tokens,
116
+ do_sample=True if temperature > 0 else False,
117
+ suppress_tokens=None,
118
+ use_cache=True,
119
+ temperature=temperature if temperature > 0 else None,
120
+ top_k=None,
121
+ top_p=top_p if temperature > 0 else None,
122
+ streamer=streamer,
123
+ )
124
+ _ = cap_model.generate(**generate_kwargs)
125
+ outputs = []
126
+ for text in streamer:
127
+ outputs.append(text)
128
+ yield "".join(outputs)
 
 
 
 
 
 
 
129
 
130
  @spaces.GPU()
131
  @torch.no_grad()