import os
import spaces
import torch
import gradio as gr

# cpu
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# gpu
model = None
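# NOTE (assumed from the spaces/ZeroGPU pattern the comments above follow):
# on a ZeroGPU Space, a GPU is only attached while a @spaces.GPU-decorated
# function is running, which is why `zero.device` prints 'cpu' at import
# time but 'cuda:0' inside greet().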
@spaces.GPU  # required on ZeroGPU so this function runs with the GPU attached
def greet(prompts, separator):
    # print(zero.device)  # <-- 'cuda:0' 🤗
    from vllm import SamplingParams, LLM
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    # download and load the model once, then reuse it across calls
    global model
    if model is None:
        LLM_MODEL_ID = "DoctorSlimm/trim-music-31"
        # LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
        # LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # requires the hf_transfer package
        fp = snapshot_download(LLM_MODEL_ID, token=os.getenv('HF_TOKEN'), revision='main')
        move_cache()
        model = LLM(fp)
    # near-greedy decoding
    sampling_params = SamplingParams(
        temperature=0.01,
        ignore_eos=False,
        max_tokens=int(512 * 2)
    )

    # split the input into multiple prompts when the separator occurs in it
    multi_prompt = False
    separator = separator.strip()
    if separator in prompts:
        multi_prompt = True
        prompts = prompts.split(separator)
    else:
        prompts = [prompts]

    # log the prompts
    for idx, pt in enumerate(prompts):
        print()
        print(f'[{idx}]:')
        print(pt)

    # batch generation: one output per prompt
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for out in output.outputs:
            generations.append(out.text)

    if multi_prompt:
        return separator.join(generations)
    return generations[0]
## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Text(
            value='hello sir!<SEP>bonjour madame...',
            placeholder='hello sir!<SEP>bonjour madame...',
            label='list of prompts separated by separator'
        ),
        gr.Text(
            value='<SEP>',
            placeholder='<SEP>',
            label='separator for your prompts'
        )
    ],
    outputs=gr.Text()
)
demo.launch(share=True)
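Following the Python-client guide linked above, a minimal sketch of calling this app over the API. The Space id passed to Client is an assumption (it mirrors the model repo id); substitute the actual <owner>/<space> id, or the share URL printed by launch():

from gradio_client import Client

# assumption: the Space lives at the same id as the model repo;
# replace with the real Space id or the share URL from launch()
client = Client("DoctorSlimm/trim-music-31")

result = client.predict(
    'hello sir!<SEP>bonjour madame...',  # prompts, joined by the separator
    '<SEP>',                             # separator
    api_name='/predict'                  # default endpoint for a gr.Interface
)
print(result)  # generations, joined by the separator when multiple prompts were sent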