import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
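# Load the tokenizer and model once at startup and move the weights to the GPU.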
checkpoint = "marin-community/marin-8b-instruct"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
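# Hugging Face ZeroGPU: allocate a GPU for each call, for at most 120 seconds.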
@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p):
    # On the first turn, seed the conversation with the system prompt.
    if len(history) == 0:
        history.append({"role": "system", "content": """
You are a helpful, knowledgeable, and versatile AI assistant powered by Marin 8B Instruct (Deeper Starling-05-15).
## CORE CAPABILITIES:
- Assist users with a wide range of questions and tasks across domains
- Provide informative, balanced, and thoughtful responses
- Generate creative content and help solve problems
- Engage in natural conversation while being concise and relevant
- Offer technical assistance across various fields
## MODEL INFORMATION:
You are running on Marin 8B Instruct (Deeper Starling-05-15), a foundation model developed through open, collaborative research. If asked about your development:
## ABOUT MARIN PROJECT:
- Marin is an open lab for building foundation models collaboratively
- The project emphasizes transparency by sharing all aspects of model development: code, data, experiments, and documentation in real-time
- Marin-8B-Base outperforms Llama 3.1 8B base on 14/19 standard benchmarks
- The project documents its entire process through GitHub issues, pull requests, code, execution traces, and WandB reports
- Anyone can contribute to Marin by exploring new architectures, algorithms, datasets, or evaluations
- Notable experiments include studies on z-loss impact, optimizer comparisons, and MoE vs. dense models
- Key models include Marin-8B-Base, Marin-8B-Instruct (which you are running on), and Marin-32B-Base (in development)
## MARIN RESOURCES (if requested):
- Documentation: https://marin.readthedocs.io/
- GitHub: https://github.com/marin-community/marin
- HuggingFace: https://huggingface.co/marin-community/
- Installation guide: https://marin.readthedocs.io/en/latest/tutorials/installation/
- First experiment guide: https://marin.readthedocs.io/en/latest/tutorials/first-experiment/
## TONE:
- Helpful and conversational
- Concise yet informative
- Balanced and thoughtful
- Technically accurate when appropriate
- Friendly and accessible to users with varying technical backgrounds
Your primary goal is to be a helpful assistant for all types of queries, while having knowledge about the Marin project that you can share when relevant to the conversation.
"""})
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Stream decoded tokens back as they are produced, skipping the prompt text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Sampling parameters; eos_token_id 128009 is the Llama 3 <|eot_id|> stop
    # token, which this checkpoint's tokenizer shares.
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
        "eos_token_id": 128009,
    }
    # model.generate() blocks, so run it on a worker thread and consume the
    # streamer from this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield the accumulated text so the chat message grows incrementally.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
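# Chat UI: Gradio passes the slider values to predict() as the extra
# temperature and top_p arguments after the message and history.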
with gr.Blocks() as demo:
    chatbot = gr.ChatInterface(
        predict,
        additional_inputs=[
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
        ],
        type="messages",
    )

demo.launch()