umar141 committed on
Commit
66fcb6e
·
verified ·
1 Parent(s): e515527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -1,25 +1,30 @@
1
- from transformers import GemmaForCausalLM, AutoTokenizer
 
2
 
3
- # Load tokenizer
4
  tokenizer = AutoTokenizer.from_pretrained("umar141/Gemma_1B_Baro_v2_vllm")
 
5
 
6
- # Load model
7
- model = GemmaForCausalLM.from_pretrained(
8
- "umar141/Gemma_1B_Baro_v2_vllm",
9
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
10
- device_map="auto"
11
- )
12
 
13
- # Tokenize prompt
14
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
15
 
16
- # Generate
17
- outputs = model.generate(
18
- input_ids=input_ids,
19
- max_new_tokens=200,
20
- do_sample=True,
21
- top_p=0.9,
22
- temperature=0.7,
23
- )
 
 
 
 
 
 
 
24
 
25
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
1
"""Streamlit chat front-end for the fine-tuned Gemma model.

Streamlit re-executes this script from top to bottom on every user
interaction, so anything expensive (loading the model weights) is wrapped in
``st.cache_resource`` and therefore runs only once per server process.
"""

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository that holds both the tokenizer and the model weights.
MODEL_ID = "umar141/Gemma_1B_Baro_v2_vllm"

# Upper bound on *generated* tokens; the prompt length does not count
# against this budget (unlike ``max_length``).
MAX_NEW_TOKENS = 150

# Streamlit page configuration. Kept ahead of every other Streamlit call,
# including the cached loader below, which may draw a spinner.
st.set_page_config(page_title="Gemma-based Chatbot", page_icon=":robot:")


@st.cache_resource
def load_model_and_tokenizer(model_id=MODEL_ID):
    """Load the tokenizer and causal LM for ``model_id`` once per process.

    Args:
        model_id: Hugging Face Hub repository id to load from.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_id)
    return loaded_tokenizer, loaded_model


# Module-level names are preserved so the rest of the script reads as before.
tokenizer, model = load_model_and_tokenizer()

# Introduction text
st.title("Gemma-based Chatbot")
st.write("This is a chatbot powered by a fine-tuned Gemma model.")

# User input
user_input = st.text_input("Ask me anything:")

# Generate a response only for non-blank input.
if user_input and user_input.strip():
    # Calling the tokenizer directly (instead of ``encode``) also returns the
    # attention mask, which ``generate`` needs because the pad token is set
    # to the EOS token below.
    inputs = tokenizer(user_input, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # ``generate`` returns prompt + continuation for a causal LM; drop the
    # prompt tokens so the user's question is not echoed back as the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # Display the response
    st.write("AI Response:")
    st.write(response)
30