LukasHug committed (verified)
Commit: dda3db7 · Parent: 98977e3

Update app.py

Files changed (1): app.py (+28, -44)
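This commit inlines model loading at startup: the @spaces.GPU-decorated load_model() helper and the module-level tokenizer/model/processor/context_len placeholders are removed, and run_inference() drops its now-unneeded global declaration.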
app.py CHANGED
@@ -130,11 +130,7 @@ class SimpleConversation:
 
 default_conversation = SimpleConversation()
 
-# Model and processor storage
-tokenizer = None
-model = None
-processor = None
-context_len = 8048
+
 
 
 def wrap_taxonomy(text):
@@ -150,42 +146,6 @@ enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
 
-# Model loading function
-@spaces.GPU
-def load_model(model_path):
-    global tokenizer, model, processor, context_len
-
-    logger.info(f"Loading model: {model_path}")
-
-    try:
-        # Check if it's a Qwen model
-        if "qwenguard" in model_path.lower():
-            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                model_path,
-                torch_dtype="auto",
-                device_map="auto"
-            )
-            processor = AutoProcessor.from_pretrained(model_path)
-            tokenizer = processor.tokenizer
-
-        # Otherwise assume it's a LlavaGuard model
-        else:
-            model = LlavaOnevisionForConditionalGeneration.from_pretrained(
-                model_path,
-                torch_dtype="auto",
-                device_map="auto",
-                trust_remote_code=True
-            )
-            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-            processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-        context_len = getattr(model.config, "max_position_embeddings", 8048)
-        logger.info(f"Model {model_path} loaded successfully")
-        return  # Remove return value to avoid Gradio warnings
-
-    except Exception as e:
-        logger.error(f"Error loading model {model_path}: {str(e)}")
-        return  # Remove return value to avoid Gradio warnings
 
 
 def get_conv_log_filename():
@@ -198,8 +158,6 @@ def get_conv_log_filename():
 # Inference function
 @spaces.GPU
 def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
-    global model, tokenizer, processor
-
     if model is None or processor is None:
         return "Model not loaded. Please wait for model to initialize."
     try:
@@ -622,7 +580,33 @@ if api_key:
     logger.info("Logged in to Hugging Face Hub")
 
 # Load model at startup
-load_model(DEFAULT_MODEL)
+model_path = DEFAULT_MODEL
+logger.info(f"Loading model: {model_path}")
+# Check if it's a Qwen model
+if "qwenguard" in model_path.lower():
+    # plain assignment: @spaces.GPU applies to functions, not statements
+    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        model_path,
+        torch_dtype="auto",
+        device_map="auto"
+    )
+    processor = AutoProcessor.from_pretrained(model_path)
+    tokenizer = processor.tokenizer
+
+# Otherwise assume it's a LlavaGuard model
+else:
+    # likewise a plain assignment; the GPU decorator stays on run_inference
+    model = LlavaOnevisionForConditionalGeneration.from_pretrained(
+        model_path,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+context_len = getattr(model.config, "max_position_embeddings", 8048)
+logger.info(f"Model {model_path} loaded successfully")
 
 demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
 demo.queue(
 
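For reference, the pattern this commit moves toward — load the checkpoint once at import time, and reserve @spaces.GPU for the function that actually runs on the accelerator — looks roughly like the sketch below. This is a minimal illustration, not the Space's actual code: the checkpoint id, prompt handling, and generation settings are placeholder assumptions, and it loads without device_map so the weights can be moved explicitly inside the GPU-scoped call.

import spaces
import torch
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

# Placeholder checkpoint; the Space's DEFAULT_MODEL is defined elsewhere in app.py.
MODEL_ID = "llava-hf/llava-onevision-qwen2-7b-ov-hf"

# Runs once at import, on CPU; ZeroGPU only attaches hardware inside @spaces.GPU calls.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)


@spaces.GPU
def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
    # Move to the GPU only while this call holds one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Single-turn chat with one image, rendered through the processor's chat template.
    messages = [{"role": "user",
                 "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_tokens,
    )
    # Decode only the newly generated tokens.
    return processor.decode(output[0][inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)

The structural point is that nothing above the decorated function touches the GPU, so the Space can start on CPU-only hardware and reserves an accelerator only for the duration of each run_inference call. It is also why a decorator cannot sit on the from_pretrained assignments themselves: in Python, a decorator may only precede a def or class statement.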