{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "private_outputs": true,
   "provenance": [],
   "authorship_tag": "ABX9TyPmvDoFpmwAf1QFBJZy7XSQ"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gemmaTitleMd"
   },
   "source": [
    "# Gemma-2 sanity check on a word-scoring puzzle\n",
    "\n",
    "Loads `google/gemma-2-9b-it`, then samples its answer to one hand-written puzzle (score words with a set of rules, then sort them by score). The expected ranking is recorded in a comment next to the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rli_enT6lBDT"
   },
   "outputs": [],
   "source": [
    "##%%\n",
    "import os\n",
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "import argparse\n",
    "import json\n",
    "import cohere\n",
    "from openai import OpenAI\n"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "##%%\n",
    "from tqdm import tqdm\n",
    "\n",
    "from collections import Counter\n",
    "\n",
    "from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n",
    "import hashlib\n",
    "\n",
    "from textgames import GAME_NAMES, GAME_IDS, LEVELS, LEVELS_HIDDEN, LEVEL_IDS, new_game\n"
   ],
   "metadata": {
    "id": "dp1F32B8oSfD"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "##%%\n",
    "gen_model_checkpoint = \"google/gemma-2-9b-it\"\n",
    "# NOTE: no weight quantization is performed. When True, the model is loaded in\n",
    "# bfloat16 with device_map=\"auto\"; when False, it loads with library defaults.\n",
    "quantize = True\n",
    "\n",
    "# get_gemma_response samples (do_sample=True), so seed every RNG up front to make\n",
    "# a Restart & Run All reproducible.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ],
   "metadata": {
    "id": "jZF8bkUcojTX"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Model-loading options only. bfloat16 weights need half the memory of float32\n",
    "# weights, and device_map=\"auto\" spreads layers over the available GPU(s)/CPU.\n",
    "kwargs = {\n",
    "    \"device_map\": \"auto\",\n",
    "    \"torch_dtype\": torch.bfloat16,\n",
    "} if quantize else {}"
   ],
   "metadata": {
    "id": "VAF5sR9arYzS"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "##%%\n",
    "gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n",
    "# kwargs holds model placement/dtype options; the tokenizer does not use them.\n",
    "tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint)"
   ],
   "metadata": {
    "id": "tzqldl8ooRVL"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "gen_model.device"
   ],
   "metadata": {
    "id": "FeBUXdkWsWrL"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def resolve_terminators(tok, candidates=(\"<end_of_turn>\", \"<|eot_id|>\")):\n",
    "    \"\"\"Return the token ids that should end generation for tokenizer `tok`.\n",
    "\n",
    "    Starts from the tokenizer's EOS id and adds each end-of-turn marker in\n",
    "    `candidates` that exists in the vocabulary. Unknown markers are skipped:\n",
    "    `convert_tokens_to_ids` maps them to the unknown-token id (or None)\n",
    "    instead of raising, and stopping on that id would be wrong.\n",
    "    \"\"\"\n",
    "    ids = [] if tok.eos_token_id is None else [tok.eos_token_id]\n",
    "    for marker in candidates:\n",
    "        marker_id = tok.convert_tokens_to_ids(marker)\n",
    "        if marker_id is None or marker_id == tok.unk_token_id or marker_id in ids:\n",
    "            continue\n",
    "        ids.append(marker_id)\n",
    "    return ids\n",
    "\n",
    "\n",
    "def get_gemma_response(text, max_new_tokens=100, temperature=0.2, top_p=1.0):\n",
    "    \"\"\"Send `text` as one user turn to the global `gen_model` and return the reply.\n",
    "\n",
    "    Args:\n",
    "        text: Prompt placed in a single user chat message.\n",
    "        max_new_tokens: Maximum number of tokens to generate.\n",
    "        temperature: Sampling temperature (sampling is always on).\n",
    "        top_p: Nucleus-sampling threshold; 1.0 keeps the full distribution.\n",
    "\n",
    "    Returns:\n",
    "        Only the newly generated text, decoded with special tokens removed.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"user\", \"content\": text}]\n",
    "\n",
    "    input_ids = tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        add_generation_prompt=True,\n",
    "        return_tensors=\"pt\",\n",
    "    ).to(gen_model.device)\n",
    "\n",
    "    # Gemma closes its turn with \"<end_of_turn>\" (Llama-3's \"<|eot_id|>\" is not a\n",
    "    # Gemma token), so resolve stop ids from the tokenizer actually loaded.\n",
    "    outputs = gen_model.generate(\n",
    "        input_ids,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        eos_token_id=resolve_terminators(tokenizer),\n",
    "        do_sample=True,\n",
    "        temperature=temperature,\n",
    "        top_p=top_p,\n",
    "    )\n",
    "\n",
    "    # generate() returns prompt + continuation; keep only the new tokens.\n",
    "    response = outputs[0][input_ids.shape[-1]:]\n",
    "    return tokenizer.decode(response, skip_special_tokens=True)"
   ],
   "metadata": {
    "id": "R5D4K-P2sPaj"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "text = \\\n",
    "\"\"\"\n",
    "Given a set of rules to calculate point, sort the set of words in decreasing order.\n",
    "When there 2 or more words with same point, sort lexicographically.\n",
    "\n",
    "Rules:\n",
    "- every pair of consecutive consonant gets 5 points\n",
    "- every pair of consecutive vowel gets 3 points\n",
    "- add 1 point if there exists exactly 1 'g' in the word\n",
    "- word less than 5 characters gets 10 points\n",
    "- word starts with gen gets 100 points\n",
    "- word ends with ta gets -1000 point\n",
    "\n",
    "Words:\n",
    "- genta\n",
    "- winata\n",
    "- hudi\n",
    "- alham\n",
    "- aji\n",
    "- ruochen\n",
    "\n",
    "Print only the answer.\n",
    "\"\"\"\n",
    "\n",
    "# Expected answer (total score, then the per-rule points that sum to it):\n",
    "# - aji         10  (10)\n",
    "# - hudi        10  (10)\n",
    "# - ruochen      8  (5 + 3)\n",
    "# - alham        5  (5)\n",
    "# - genta     -894  (5 + 1 + 100 - 1000)\n",
    "# - winata   -1000  (-1000)"
   ],
   "metadata": {
    "id": "T_tk4hTGsxsR"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Generation is sampled, so draw several answers to gauge how consistent the model is.\n",
    "N_SAMPLES = 2\n",
    "for sample_idx in range(N_SAMPLES):\n",
    "    print(f\"--- sample {sample_idx + 1} ---\")\n",
    "    print(get_gemma_response(text))"
   ],
   "metadata": {
    "id": "05OI36v6vGoY"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}