File size: 7,487 Bytes
4737b79
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: chatbot_core_components\n",
    "\n",
    "A chatbot whose replies can be text, files, Plotly/Bokeh/Matplotlib plots, images, a gallery, audio, or video. Pick the kind of reply with the **Response Type** radio button."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The demo also imports plotly, bokeh and matplotlib for its plot replies.\n",
    "%pip install -q gradio plotly bokeh matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloading files from the demo repo\n",
    "import os\n",
    "# makedirs(exist_ok=True) keeps this cell re-runnable; os.mkdir raises FileExistsError once 'files' exists.\n",
    "os.makedirs('files', exist_ok=True)\n",
    "!wget -q -O files/audio.wav https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components/files/audio.wav\n",
    "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components/files/avatar.png\n",
    "!wget -q -O files/sample.txt https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components/files/sample.txt\n",
    "!wget -q -O files/world.mp4 https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components/files/world.mp4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "import plotly.express as px\n",
    "\n",
    "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video)\n",
    "# and multimodal replies (files, Plotly/Bokeh/Matplotlib plots, images, a gallery, audio, & video).\n",
    "# The kind of reply is picked with the \"Response Type\" radio button.\n",
    "\n",
    "# Media files downloaded into ./files by the previous cell.\n",
    "AVATAR = os.path.join(\"files\", \"avatar.png\")\n",
    "AUDIO = os.path.join(\"files\", \"audio.wav\")\n",
    "VIDEO = os.path.join(\"files\", \"world.mp4\")\n",
    "SAMPLE_TXT = os.path.join(\"files\", \"sample.txt\")\n",
    "\n",
    "# Response types answered with a plain (filepath, \"description\") tuple instead of a component.\n",
    "FILE_REPLIES = {\n",
    "    \"audio_file\": AUDIO,\n",
    "    \"image_file\": AVATAR,\n",
    "    \"video_file\": VIDEO,\n",
    "    \"txt_file\": SAMPLE_TXT,\n",
    "}\n",
    "\n",
    "\n",
    "def random_plot():\n",
    "    \"\"\"Return a Plotly scatter plot of the iris dataset that ships with plotly.\"\"\"\n",
    "    df = px.data.iris()\n",
    "    fig = px.scatter(\n",
    "        df,\n",
    "        x=\"sepal_width\",\n",
    "        y=\"sepal_length\",\n",
    "        color=\"species\",\n",
    "        size=\"petal_length\",\n",
    "        hover_data=[\"petal_width\"],\n",
    "    )\n",
    "    return fig\n",
    "\n",
    "\n",
    "def print_like_dislike(x: gr.LikeData):\n",
    "    \"\"\"Print the index, value and liked flag of a chatbot message the user rated.\"\"\"\n",
    "    print(x.index, x.value, x.liked)\n",
    "\n",
    "\n",
    "def random_bokeh_plot():\n",
    "    \"\"\"Return a Bokeh jitter plot of highway mpg per car class with 20%-80% quantile whiskers.\"\"\"\n",
    "    from bokeh.models import ColumnDataSource, Whisker\n",
    "    from bokeh.plotting import figure\n",
    "    from bokeh.sampledata.autompg2 import autompg2 as df\n",
    "    from bokeh.transform import factor_cmap, jitter\n",
    "\n",
    "    classes = sorted(df[\"class\"].unique())\n",
    "\n",
    "    p = figure(\n",
    "        height=400,\n",
    "        x_range=classes,\n",
    "        background_fill_color=\"#efefef\",\n",
    "        title=\"Car class vs HWY mpg with quintile ranges\",\n",
    "    )\n",
    "    p.xgrid.grid_line_color = None\n",
    "\n",
    "    g = df.groupby(\"class\")\n",
    "    upper = g.hwy.quantile(0.80)\n",
    "    lower = g.hwy.quantile(0.20)\n",
    "    source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))\n",
    "\n",
    "    error = Whisker(\n",
    "        base=\"base\",\n",
    "        upper=\"upper\",\n",
    "        lower=\"lower\",\n",
    "        source=source,\n",
    "        level=\"annotation\",\n",
    "        line_width=2,\n",
    "    )\n",
    "    error.upper_head.size = 20\n",
    "    error.lower_head.size = 20\n",
    "    p.add_layout(error)\n",
    "\n",
    "    # scatter() draws circle markers by default; figure.circle(size=...) is deprecated since Bokeh 3.4.\n",
    "    p.scatter(\n",
    "        jitter(\"class\", 0.3, range=p.x_range),\n",
    "        \"hwy\",\n",
    "        source=df,\n",
    "        alpha=0.5,\n",
    "        size=13,\n",
    "        line_color=\"white\",\n",
    "        color=factor_cmap(\"class\", \"Light6\", classes),\n",
    "    )\n",
    "    return p\n",
    "\n",
    "\n",
    "def random_matplotlib_plot():\n",
    "    \"\"\"Return a Matplotlib line chart of synthetic outbreak case counts for four countries.\"\"\"\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    from matplotlib.figure import Figure\n",
    "\n",
    "    countries = [\"USA\", \"Canada\", \"Mexico\", \"UK\"]\n",
    "    months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n",
    "    m = months.index(\"January\")\n",
    "    r = 3.2\n",
    "    start_day = 30 * m\n",
    "    final_day = 30 * (m + 1)\n",
    "    x = np.arange(start_day, final_day + 1)\n",
    "    pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n",
    "    df = pd.DataFrame({\"day\": x})\n",
    "    for country in countries:\n",
    "        df[country] = x ** (r) * (pop_count[country] + 1)\n",
    "\n",
    "    # Build the figure without pyplot: pyplot keeps every figure it creates alive until\n",
    "    # plt.close() is called, so a long-running app would leak one figure per request.\n",
    "    fig = Figure()\n",
    "    ax = fig.subplots()\n",
    "    ax.plot(df[\"day\"], df[countries].to_numpy())\n",
    "    ax.set_title(\"Outbreak in \" + months[m])\n",
    "    ax.set_ylabel(\"Cases\")\n",
    "    ax.set_xlabel(\"Days since Day 0\")\n",
    "    ax.legend(countries)\n",
    "    return fig\n",
    "\n",
    "\n",
    "def add_message(history, message):\n",
    "    \"\"\"Append the user's uploaded files and text to the chat history.\n",
    "\n",
    "    Each file is appended as a one-element tuple and the text as a plain string;\n",
    "    every new entry starts with a None reply for `bot` to fill in. Returns the\n",
    "    updated history and a cleared, non-interactive textbox.\n",
    "    \"\"\"\n",
    "    for x in message[\"files\"]:\n",
    "        history.append(((x,), None))\n",
    "    # Skip empty text so a files-only submission does not add a blank bubble.\n",
    "    if message[\"text\"]:\n",
    "        history.append((message[\"text\"], None))\n",
    "    return history, gr.MultimodalTextbox(value=None, interactive=False)\n",
    "\n",
    "\n",
    "def bot(history, response_type):\n",
    "    \"\"\"Fill in the reply to the latest chat message according to `response_type`.\n",
    "\n",
    "    The history is returned unchanged when it is empty or when the latest message\n",
    "    already has a reply (e.g. an empty submission added nothing new).\n",
    "    \"\"\"\n",
    "    if not history or history[-1][1] is not None:\n",
    "        return history\n",
    "\n",
    "    if response_type == \"plot\":\n",
    "        reply = gr.Plot(random_plot())\n",
    "    elif response_type == \"bokeh_plot\":\n",
    "        reply = gr.Plot(random_bokeh_plot())\n",
    "    elif response_type == \"matplotlib_plot\":\n",
    "        reply = gr.Plot(random_matplotlib_plot())\n",
    "    elif response_type == \"gallery\":\n",
    "        reply = gr.Gallery([AVATAR, AVATAR])\n",
    "    elif response_type == \"image\":\n",
    "        reply = gr.Image(AVATAR)\n",
    "    elif response_type == \"video\":\n",
    "        reply = gr.Video(VIDEO)\n",
    "    elif response_type == \"audio\":\n",
    "        reply = gr.Audio(AUDIO)\n",
    "    elif response_type in FILE_REPLIES:\n",
    "        reply = (FILE_REPLIES[response_type], \"description\")\n",
    "    else:\n",
    "        reply = \"Cool!\"\n",
    "    history[-1][1] = reply\n",
    "    return history\n",
    "\n",
    "\n",
    "with gr.Blocks(fill_height=True) as demo:\n",
    "    chatbot = gr.Chatbot(\n",
    "        elem_id=\"chatbot\",\n",
    "        bubble_full_width=False,\n",
    "        scale=1,\n",
    "    )\n",
    "    response_type = gr.Radio(\n",
    "        [\n",
    "            \"audio_file\",\n",
    "            \"image_file\",\n",
    "            \"video_file\",\n",
    "            \"txt_file\",\n",
    "            \"plot\",\n",
    "            \"matplotlib_plot\",\n",
    "            \"bokeh_plot\",\n",
    "            \"image\",\n",
    "            \"text\",\n",
    "            \"gallery\",\n",
    "            \"video\",\n",
    "            \"audio\",\n",
    "        ],\n",
    "        value=\"text\",\n",
    "        label=\"Response Type\",\n",
    "    )\n",
    "\n",
    "    chat_input = gr.MultimodalTextbox(\n",
    "        interactive=True,\n",
    "        placeholder=\"Enter message or upload file...\",\n",
    "        show_label=False,\n",
    "    )\n",
    "\n",
    "    # On submit: record the user's message and lock the textbox, generate the\n",
    "    # reply, then unlock the textbox again.\n",
    "    chat_msg = chat_input.submit(\n",
    "        add_message, [chatbot, chat_input], [chatbot, chat_input]\n",
    "    )\n",
    "    bot_msg = chat_msg.then(\n",
    "        bot, [chatbot, response_type], chatbot, api_name=\"bot_response\"\n",
    "    )\n",
    "    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n",
    "\n",
    "    chatbot.like(print_like_dislike, None, None)\n",
    "\n",
    "demo.queue()\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}