File size: 2,205 Bytes
882bd69
fa9231d
72ff248
832ce7b
72ff248
a0ba541
72ff248
a0ba541
 
 
832ce7b
a0ba541
2629ae5
a0ba541
72ff248
a0ba541
 
 
 
fa9231d
0b2a88c
a0ba541
72ff248
a0ba541
 
 
 
 
 
fa9231d
882bd69
72ff248
a0ba541
72ff248
 
9998c92
a0ba541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2629ae5
882bd69
a0ba541
 
72ff248
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

# Retrieve the token from environment variables
#api_token = os.getenv("HF_TOKEN").strip()

import torch
from flask import Flask, request, jsonify
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from PIL import Image
import io
import base64

# Model identifier, defined once so the model and tokenizer loads can never
# silently diverge (the original repeated the string literal in both calls).
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

app = Flask(__name__)

# 4-bit NF4 quantization with double quantization: shrinks the 8B model to
# fit on a single consumer GPU while computing in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Load model once at import time; device_map="auto" lets accelerate place the
# weights on the available GPU(s). trust_remote_code is required because this
# repo ships custom multimodal modeling code.
# NOTE(review): attn_implementation="flash_attention_2" requires the
# flash-attn package and a supported GPU — confirm it is installed, otherwise
# loading will fail.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)

def decode_base64_image(base64_string):
    """Decode a base64-encoded image payload into an RGB PIL image.

    Args:
        base64_string: Base64 text of an image file (PNG, JPEG, ...).

    Returns:
        PIL.Image.Image: the decoded image, converted to RGB mode so that
        downstream model code sees a consistent three-channel format.

    Raises:
        binascii.Error: if the string is not valid base64.
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    raw_bytes = base64.b64decode(base64_string)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer).convert('RGB')

@app.route('/analyze', methods=['POST'])
def analyze_input():
    """Answer a medical question, optionally grounded in a base64 image.

    Expects a JSON body: {"question": str, "image": optional base64 str}.
    Returns {"status": "success", "response": str} on success,
    {"status": "error", "message": str} with 400 (bad request) or 500
    (inference failure) otherwise.
    """
    # Bug fix: the original used request.json outside the try block, so a
    # non-JSON body produced an unhandled HTML 500 instead of the JSON error
    # envelope. get_json(silent=True) returns None for malformed bodies,
    # which we turn into an explicit 400.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({
            'status': 'error',
            'message': 'Request body must be a JSON object'
        }), 400

    question = data.get('question', '')
    base64_image = data.get('image', None)

    try:
        if base64_image:
            # Multimodal path: decode the image and build combined inputs.
            image = decode_base64_image(base64_image)
            # Bug fix: with device_map="auto" the model weights live on GPU,
            # but the tokenizer produces CPU tensors — move them to the
            # model's device before generation to avoid a device mismatch.
            input_ids = tokenizer(
                question, return_tensors="pt"
            ).input_ids.to(model.device)
            inputs = model.prepare_inputs_for_generation(
                input_ids=input_ids,
                images=[image]
            )
            outputs = model.generate(**inputs, max_new_tokens=256)
        else:
            # Text-only path.
            inputs = tokenizer(question, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=256)

        # Decode the full generated sequence into plain text.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return jsonify({
            'status': 'success',
            'response': response
        })

    except Exception as e:
        # Top-level service boundary: surface the failure as a JSON 500
        # rather than letting Flask render an HTML traceback page.
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


if __name__ == '__main__':
    # Development entry point: runs Flask's built-in server on localhost.
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution if the port is reachable — never
    # expose this configuration in production; use a WSGI server instead.
    app.run(debug=True)