fix_tool_call #78
by wukaixingxp - opened
No description provided.
wukaixingxp changed pull request status to open
Tested with code:
import json

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Llama4ForConditionalGeneration


# Define the tool function
def get_current_temperature(location: str) -> str:
    """
    Returns the current temperature at a location.
    """
    # Replace with a real API call as needed
    return "22.0 degrees Celsius"


def tool_weather():
    return {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City and country e.g. Bogotá, Colombia",
                    }
                },
                "required": ["location"],
            },
        },
    }


def test_tool(model, tokenizer, processor):
    print("------------test tool------------")
    # Prepare chat messages
    messages = [
        {"role": "user", "content": "What's the temperature in Paris right now?"}
    ]
    # Apply chat template with tool
    input_ids = tokenizer.apply_chat_template(
        messages,
        tools=[tool_weather()],
        add_generation_prompt=True,
        return_tensors="pt",
    )
    print("prompt:", tokenizer.decode(input_ids[0], skip_special_tokens=False))
    # Generate model output with tool call
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=200)
    generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print("generated:", generated.split("\n")[-1])
    # Parse tool call from model output
    if "get_weather" in generated and generated.endswith("]"):
        # fake tool parsing
        tool_call = {"name": "get_weather", "arguments": {"location": "Paris"}}
        func_name = tool_call["name"]
        func_args = tool_call["arguments"]
        print(f"Tool call: {func_name}({func_args})")
    else:
        tool_call = None
    # Execute the tool if parsed successfully
    if tool_call:
        result = get_current_temperature(**func_args)
        messages.append(
            {"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}]}
        )
        messages.append({"role": "tool", "name": func_name, "content": str(result)})
        # Continue generation with tool result
        input_ids = tokenizer.apply_chat_template(
            messages,
            tools=[tool_weather()],
            add_generation_prompt=True,
            return_tensors="pt",
        )
        print("prompt2:", tokenizer.decode(input_ids[0], skip_special_tokens=False))
        output_ids = model.generate(input_ids=input_ids, max_new_tokens=200)
        generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print("generated:", generated.split("\n")[-1])
    else:
        print("No valid tool call generated by the model.")


def test_image(model, processor):
    print("------------test image------------")
    url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
    url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
    response = requests.get(url1, stream=True)
    if response.status_code == 200:
        image1 = Image.open(response.raw).convert("RGB")
    else:
        raise ValueError(f"Failed to fetch image from URL: {url1}")
    response = requests.get(url2, stream=True)
    if response.status_code == 200:
        image2 = Image.open(response.raw).convert("RGB")
    else:
        raise ValueError(f"Failed to fetch image from URL: {url2}")
    prompt = "Can you describe how these two images are similar, and how they differ?<|image|> <|image|>"
    inputs = processor(images=[image1, image2], text=prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=50)
    # Decode and print the generated text
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
    print("Generated Text:", response)


def test_chat(model, tokenizer, processor):
    print("------------test chat------------")
    messages = [
        {"role": "user", "content": "Who are you?"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
    outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
    print(outputs[0])


def main():
    # Load model and tokenizer
    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    revision = "refs/pr/78"
    print("loading model")
    print(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    processor = AutoProcessor.from_pretrained(model_id, revision=revision)
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        revision=revision,
    )
    test_tool(model, tokenizer, processor)
    test_image(model, processor)
    test_chat(model, tokenizer, processor)


main()
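Note: the script above stubs out tool-call parsing ("fake tool parsing"). A minimal sketch of a real parser for the pythonic [func_name(arg="value"), ...] format the model emits (illustration only, not part of this PR) could look like:

import ast

def parse_pythonic_tool_calls(text: str):
    """Parse a '[func(a=1, b="x"), ...]' string into {'name', 'arguments'} dicts."""
    tree = ast.parse(text.strip(), mode="eval")  # the output is a Python list of call expressions
    calls = []
    for node in tree.body.elts:
        calls.append({
            "name": node.func.id,
            "arguments": {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords},
        })
    return calls

# parse_pythonic_tool_calls('[get_weather(location="Paris")]')
# -> [{'name': 'get_weather', 'arguments': {'location': 'Paris'}}]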
Test log:
loading model
meta-llama/Llama-4-Scout-17B-16E-Instruct
chat_template.jinja: 100%|█████████████████████████████████████████████| 7.33k/7.33k [00:00<00:00, 56.0MB/s]
Fetching 50 files: 100%|████████████████████████████████████████████████████| 50/50 [00:01<00:00, 43.35it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████| 50/50 [00:49<00:00, 1.02it/s]
------------test tool------------
prompt: <|begin_of_text|><|header_start|>system<|header_end|>
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
1. FUNCTION CALLS:
- ONLY use functions that are EXPLICITLY listed in the function list below
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
Examples:
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
INCORRECT: get_weather(location="New York")
INCORRECT: Let me check the weather: [get_weather(location="New York")]
INCORRECT: [get_events(location="Singapore")] <- If function not in list
2. RESPONSE RULES:
- For pure function requests matching a listed function: ONLY output the function call(s)
- For knowledge questions: ONLY output text
- For missing parameters: ONLY request the specific missing parameters
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
- NEVER combine text and function calls in the same response
- NEVER suggest alternative functions when the requested service is unavailable
- NEVER create or invent new functions not listed below
3. STRICT BOUNDARIES:
- ONLY use functions from the list below - no exceptions
- NEVER use a function as an alternative to unavailable information
- NEVER call functions not present in the function list
- NEVER add explanatory text to function calls
- NEVER respond with empty brackets
- Use proper Python/JSON syntax for function calls
- Check the function list carefully before responding
4. TOOL RESPONSE HANDLING:
- When receiving tool responses: provide concise, natural language responses
- Don't repeat tool response verbatim
- Don't add supplementary information
Here is a list of functions in JSON format that you can invoke:
<|eot|><|header_start|>user<|header_end|>
What's the temperature in Paris right now?<|eot|><|header_start|>assistant<|header_end|>
/home/kaiwu/.conda/envs/llama4/lib/python3.10/site-packages/transformers/generation/utils.py:2347: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
warnings.warn(
generated: [get_weather(location="Paris")]
Tool call: get_weather({'location': 'Paris'})
prompt2: <|begin_of_text|><|header_start|>system<|header_end|>
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
1. FUNCTION CALLS:
- ONLY use functions that are EXPLICITLY listed in the function list below
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
Examples:
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
INCORRECT: get_weather(location="New York")
INCORRECT: Let me check the weather: [get_weather(location="New York")]
INCORRECT: [get_events(location="Singapore")] <- If function not in list
2. RESPONSE RULES:
- For pure function requests matching a listed function: ONLY output the function call(s)
- For knowledge questions: ONLY output text
- For missing parameters: ONLY request the specific missing parameters
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
- NEVER combine text and function calls in the same response
- NEVER suggest alternative functions when the requested service is unavailable
- NEVER create or invent new functions not listed below
3. STRICT BOUNDARIES:
- ONLY use functions from the list below - no exceptions
- NEVER use a function as an alternative to unavailable information
- NEVER call functions not present in the function list
- NEVER add explanatory text to function calls
- NEVER respond with empty brackets
- Use proper Python/JSON syntax for function calls
- Check the function list carefully before responding
4. TOOL RESPONSE HANDLING:
- When receiving tool responses: provide concise, natural language responses
- Don't repeat tool response verbatim
- Don't add supplementary information
Here is a list of functions in JSON format that you can invoke:
<|eot|><|header_start|>user<|header_end|>
What's the temperature in Paris right now?<|eot|><|header_start|>assistant<|header_end|>
[get_weather(location="Paris")]<|eot|><|header_start|>ipython<|header_end|>
"22.0 degrees Celsius"<|eot|><|header_start|>assistant<|header_end|>
generated: The current temperature in Paris is 22.0 degrees Celsius.
------------test image------------
Generated Text:
The two images share a common theme of featuring anthropomorphic animals in a natural setting. The similarities between the images include:
* Both images feature a single animal as the main subject.
* The animals are dressed in human-like clothing, with the rabbit
------------test chat------------
I'm an AI assistant designed by Meta. I'm here to answer your questions, share interesting ideas and maybe even surprise you with a fresh perspective. What's on your mind?<|eot|>
Tested with a vLLM server as well; there is no longer any need to pass a separate chat template:
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct -tp 4 --seed 0 --revision refs/pr/78 --max-model-len=120000 --host 0.0.0.0 --port 8000 --enable-auto-tool-choice --tool-call-parser pythonic --limit-mm-per-prompt image=5 --max-num-seqs 30
Example request and response:
curl -X POST http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"messages": [
{
"role": "user",
"content": "what is the weather in SF"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Retrieve the current temperature for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, state, or country for which to fetch the temperature"
}
},
"required": [
"location"
],
"additionalProperties": false
},
"strict": true
}
}
],
"tool_choice": "auto"
}'
{"id":"chatcmpl-711d7c481edb4ba89ac340527538eec6","object":"chat.completion","created":1747944223,"model":"meta-llama/Llama-4-Scout-17B-16E-Instruct","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":null,"content":null,"tool_calls":[{"id":"chatcmpl-tool-2139456ac16f4274ad29bfcc2d256eb6","type":"function","function":{"name":"get_weather","arguments":"{\"location\": \"San Francisco\"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null}],"usage":{"prompt_tokens":650,"total_tokens":658,"completion_tokens":8,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_transfer_params":null}
vllm log:
INFO 05-22 13:03:43 [logger.py:42] Received request chatcmpl-711d7c481edb4ba89ac340527538eec6: prompt: '<|begin_of_text|><|header_start|>system<|header_end|>\n\nYou are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don\'t have access to [Unavailable service] information"\n- If a function is not in the list, respond ONLY with internal knowledge or "I don\'t have access to [Unavailable service] information"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function\'s purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location="New York")\nINCORRECT: Let me check the weather: [get_weather(location="New York")]\nINCORRECT: [get_events(location="Singapore")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or "I don\'t have access to [Unavailable service] information". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. 
TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don\'t repeat tool response verbatim\n- Don\'t add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n[\n {\n "type": "function",\n "function": {\n "name": "get_weather",\n "description": "Retrieve the current temperature for a specified location",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The city, state, or country for which to fetch the temperature"\n }\n },\n "required": [\n "location"\n ],\n "additionalProperties": false\n },\n "strict": true\n }\n }\n]<|eot|><|header_start|>user<|header_end|>\n\nwhat is the weather in SF<|eot|><|header_start|>assistant<|header_end|>\n\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.9, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=119350, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
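For reference, the same tool-call request can also be made from Python through vLLM's OpenAI-compatible endpoint. This is a minimal sketch assuming the default port from the serve command above; the API key is a placeholder since the server is started without authentication:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Retrieve the current temperature for a specified location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "what is the weather in SF"}],
    tools=tools,
    tool_choice="auto",
)
# The parsed tool calls are returned directly on the message object.
print(response.choices[0].message.tool_calls)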
If the user provides a system message, the template only appends basic tool_call information instead of the default comprehensive tool_call prompt. It is the user's responsibility to adjust their system message to get the correct output.
A BAD example:
curl -X POST http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"messages": [
{
"role": "system",
"content": "you are a helpful assisant who can do tool call"
},
{
"role": "user",
"content": "what is the weather in SF"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Retrieve the current temperature for a specified location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city, state, or country for which to fetch the temperature"
}
},
"required": [
"location"
],
"additionalProperties": false
},
"strict": true
}
}
],
"tool_choice": "auto"
}'
{"id":"chatcmpl-1b6f2c87daba4223adab662df559e963","object":"chat.completion","created":1747944288,"model":"meta-llama/Llama-4-Scout-17B-16E-Instruct","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":null,"content":null,"tool_calls":[{"id":"chatcmpl-tool-b397d537d4b84dada3d650690e46411b","type":"function","function":{"name":"get_weather","arguments":"{\"location\": \"San Francisco\"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null}],"usage":{"prompt_tokens":190,"total_tokens":198,"completion_tokens":8,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_transfer_params":null}
vllm log:
INFO 05-22 13:04:48 [logger.py:42] Received request chatcmpl-1b6f2c87daba4223adab662df559e963: prompt: '<|begin_of_text|><|header_start|>system<|header_end|>\n\nyou are a helpful assisant who can do tool call\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n[\n {\n "type": "function",\n "function": {\n "name": "get_weather",\n "description": "Retrieve the current temperature for a specified location",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The city, state, or country for which to fetch the temperature"\n }\n },\n "required": [\n "location"\n ],\n "additionalProperties": false\n },\n "strict": true\n }\n }\n]<|eot|><|header_start|>user<|header_end|>\n\nwhat is the weather in SF<|eot|><|header_start|>assistant<|header_end|>\n\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.9, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=119810, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
Again, with a user-provided system message only the basic tool_call information is applied instead of the default comprehensive tool_call prompt; the user must adjust their system message to get the correct output.
Finally, a plain chat request without tools still works as expected:
curl -X POST http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"messages": [
{
"role": "system",
"content": "you are a helpful assisant."
},
{
"role": "user",
"content": "who are you?"
}
]
}'
{"id":"chatcmpl-b9b22fd0ea154c588bfc4883ba4a6c7e","object":"chat.completion","created":1747933707,"model":"meta-llama/Llama-4-Scout-17B-16E-Instruct","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":null,"content":"I'm an AI assistant designed to help users with a wide range of tasks and provide information on various topics. I'm here to assist you with any questions or problems you might have, so feel free to ask me anything! \n\nI can help with:\n\n* Answering questions on various subjects, such as science, history, technology, and more\n* Providing definitions and explanations\n* Offering suggestions and recommendations\n* Assisting with language-related tasks, such as translation and grammar\n* And many more!\n\nWhat's on your mind? How can I help you today?","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":26,"total_tokens":138,"completion_tokens":112,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_transfer_params":null}
vllm log:
INFO 05-22 10:08:27 [logger.py:42] Received request chatcmpl-b9b22fd0ea154c588bfc4883ba4a6c7e: prompt: '<|begin_of_text|><|header_start|>system<|header_end|>\n\nyou are a helpful assisant.<|eot|><|header_start|>user<|header_end|>\n\nwho are you?<|eot|><|header_start|>assistant<|header_end|>\n\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.9, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=119974, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
vontimitta changed pull request status to merged