techindia2025 committed on
Commit b80af5b · verified · 1 Parent(s): fe95d2b

Update app.py

Files changed (1):
  1. app.py +65 -0
app.py CHANGED
@@ -0,0 +1,65 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import spaces
+
+ # Model name
+ model_name = "medalpaca/medalpaca-7b"
+
+ # Load tokenizer and model globally for efficiency
+ print(f"CUDA available: {torch.cuda.is_available()}")
+ if torch.cuda.is_available():
+     print(f"GPU device count: {torch.cuda.device_count()}")
+     print(f"GPU device name: {torch.cuda.get_device_name(0)}")
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",  # use GPU if available
+     load_in_8bit=torch.cuda.is_available(),  # 8-bit quantization on GPU (requires bitsandbytes)
+ )
+
+ def format_prompt(message, chat_history):
+     prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+     if chat_history:
+         prompt += "Previous conversation:\n"
+         for turn in chat_history:  # each turn is a {"role": ..., "content": ...} dict (Chatbot type="messages")
+             role = "Human" if turn["role"] == "user" else "Assistant"
+             prompt += f"{role}: {turn['content']}\n"
+     prompt += f"Human: {message}\nAssistant:"
+     return prompt
+
+ @spaces.GPU  # required for ZeroGPU Spaces
+ def generate_response(message, chat_history):
+     prompt = format_prompt(message, chat_history)
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+         )
+     full_output = tokenizer.decode(generation_output[0], skip_special_tokens=True)
+     response = full_output.split("Assistant:")[-1].strip()
+     chat_history += [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
+     return "", chat_history
+
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
+     gr.Markdown("# MedAlpaca Medical Chatbot")
+     gr.Markdown("A specialized medical chatbot powered by MedAlpaca-7B.")
+     gr.Markdown("Ask medical questions and get responses from a model trained on medical data.")
+
+     chatbot = gr.Chatbot(type="messages")
+     msg = gr.Textbox(placeholder="Type your medical question here...")
+     clear = gr.Button("Clear")
+
+     msg.submit(generate_response, [msg, chatbot], [msg, chatbot])  # pass the GPU-decorated function
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == "__main__":
+     print("Starting Gradio app...")
+     demo.launch(server_name="0.0.0.0")
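
For reference, a minimal sketch of what the prompt assembly above produces, assuming format_prompt is available in the session (importing app.py directly would also run the module-level model load). The history entries below are hypothetical examples in the role/content dict format that gr.Chatbot(type="messages") passes to the callback.

# Hypothetical two-turn history in the role/content dict format used by
# gr.Chatbot(type="messages"); only the prompt string is exercised here.
history = [
    {"role": "user", "content": "What causes iron-deficiency anemia?"},
    {"role": "assistant", "content": "Common causes include chronic blood loss and low dietary iron intake."},
]

# With format_prompt in scope (e.g. pasted into a REPL), print the assembled prompt.
print(format_prompt("How is it usually treated?", history))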