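The full app.py for a ZeroGPU Space that serves mistral-community/Mistral-7B-v0.2 through vLLM behind a Gradio interface: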
import spaces
import torch
import gradio as gr

# Outside a @spaces.GPU function, ZeroGPU reports the CPU device.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# The GPU is only attached while a @spaces.GPU-decorated function runs.
@spaces.GPU
def greet(user):
    # print(zero.device)  # <-- 'cuda:0' 🤗
    from vllm import LLM, SamplingParams
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    # Download the weights (a no-op after the first call thanks to the cache).
    LLM_MODEL_ID = "mistral-community/Mistral-7B-v0.2"
    fp = snapshot_download(LLM_MODEL_ID)
    move_cache()

    # Note: the vLLM engine is rebuilt on every request, which keeps the
    # example self-contained but adds substantial latency per call.
    model = LLM(fp)
    sampling_params = SamplingParams(
        temperature=0.3,
        ignore_eos=False,
        max_tokens=1024,  # 512 * 2
    )

    prompts = [user]
    model_outputs = model.generate(prompts, sampling_params)
    generations = []
    for output in model_outputs:
        for completion in output.outputs:
            generations.append(completion.text)
    return generations[0]

## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(fn=greet, inputs=gr.Text(), outputs=gr.Text())
demo.launch()  # share=True is unnecessary (and ignored) on Spaces
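Following the guide linked above, the running Space can then be queried from Python with gradio_client. A minimal sketch; "user/space-name" is a hypothetical Space ID, not the actual one:

from gradio_client import Client

# "user/space-name" is a placeholder; substitute your own Space ID.
client = Client("user/space-name")
result = client.predict(
    "Write a haiku about GPUs.",  # maps to the single gr.Text() input
    api_name="/predict",          # default endpoint name for gr.Interface
)
print(result)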