# Retrieve the Hugging Face token from environment variables if the model
# requires gated access (needs `import os`):
# api_token = os.getenv("HF_TOKEN", "").strip()
import torch
from flask import Flask, request, jsonify
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import io
import base64
app = Flask(__name__)
# Quantization configuration: 4-bit NF4 with double quantization to fit the
# 8B model into limited GPU memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)
# Load the model; flash_attention_2 requires the flash-attn package and a
# compatible GPU, so drop that argument if it is unavailable
model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True
)
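# Optional sanity check: report how much memory the quantized weights use
# (get_memory_footprint is a standard transformers PreTrainedModel method)
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")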
def decode_base64_image(base64_string):
    # Decode a base64 string into an RGB PIL image
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data)).convert('RGB')
    return image
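# For reference only: a client can build the expected 'image' field by
# base64-encoding an image file. This helper is illustrative and is not
# used by the server itself.
def encode_image_to_base64(path):
    with open(path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')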
@app.route('/analyze', methods=['POST'])
def analyze_input():
    data = request.get_json(force=True)
    question = data.get('question', '')
    base64_image = data.get('image', None)
    try:
        if base64_image:
            # Multimodal path: pair the text prompt with the decoded image.
            # Note: this relies on the model's custom remote code; the exact
            # image-handling API may differ, so check the model card.
            image = decode_base64_image(base64_image)
            input_ids = tokenizer(question, return_tensors="pt").input_ids.to(model.device)
            inputs = model.prepare_inputs_for_generation(
                input_ids=input_ids,
                images=[image]
            )
            outputs = model.generate(**inputs, max_new_tokens=256)
        else:
            # Text-only processing; move input tensors to the model's device
            inputs = tokenizer(question, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=256)
        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({
            'status': 'success',
            'response': response
        })
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500
if __name__ == '__main__':
    # Bind to all interfaces; 7860 is the conventional Hugging Face Spaces port
    app.run(host='0.0.0.0', port=7860, debug=True)
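# Example request from a hypothetical client (adjust host/port as deployed):
#
#   import base64, requests
#   with open("scan.png", "rb") as f:
#       payload = {
#           "question": "Describe any abnormalities in this image.",
#           "image": base64.b64encode(f.read()).decode("utf-8"),
#       }
#   r = requests.post("http://localhost:7860/analyze", json=payload)
#   print(r.json())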