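"""Three-agent chatbot demo.

Loads three causal language models, generates a reply from each in parallel
threads, and serves the results through a Gradio interface.
"""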
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import concurrent.futures

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models and tokenizers
def load_model(name):
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    
    # Define pad token explicitly
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
    
    return tokenizer, model.to(device)

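# Note: google/gemma-3-1b-it is a gated checkpoint on the Hugging Face Hub; you may
# need to accept its license and authenticate (e.g. `huggingface-cli login`) before
# from_pretrained can download it. Loading all three models also takes several GB of
# memory, so swap in smaller checkpoints if resources are tight.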
tokenizer1, model1 = load_model("google/gemma-3-1b-it")
tokenizer2, model2 = load_model("facebook/opt-125m")
tokenizer3, model3 = load_model("gpt2")

# Generation function
def generate_response(model, tokenizer, prompt):
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=100,  # cap generated tokens; max_length would count the prompt too
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    # Decode only the newly generated tokens so the reply doesn't echo the prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Multi-agent handler: query all three models concurrently and collect the replies.
# Threads let the generations overlap, though heavy tensor ops may still serialize
# on a single device.
def multi_agent_chat(user_input):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(generate_response, model1, tokenizer1, user_input),
            executor.submit(generate_response, model2, tokenizer2, user_input),
            executor.submit(generate_response, model3, tokenizer3, user_input)
        ]
        results = [f.result() for f in futures]
    # One value per Gradio output component
    return tuple(results)

# Gradio Interface
interface = gr.Interface(
    fn=multi_agent_chat,
    inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
    outputs=[
        gr.Textbox(label="Agent 1 (google/gemma-3-1b-it)"),
        gr.Textbox(label="Agent 2 (facebook/opt-125m)"),
        gr.Textbox(label="Agent 3 (GPT-2)")
    ],
    title="3-Agent AI Chatbot",
    description="Three GPT-style agents respond to your input in parallel!"
)

interface.launch()