Created January 26, 2026 14:00
understanding_next_token_prediction_2026_01.ipynb
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "mount_file_id": "1E_vbYj3d0YGNttOByVnXpslqWlLVXObB", | |
| "authorship_tag": "ABX9TyPea1L3sfnRwgVVyilHvWz6", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/3e0e00761158e277414a6d13b40eeef0/understanding_next_token_prediction_2026_01.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Understanding Next Token Prediction" | |
| ], | |
| "metadata": { | |
| "id": "v18e2XKwRKN2" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Let's look at a [next token prediction app](https://huggingface.co/spaces/alonsosilva/NextTokenPrediction).\n", | |
| "\n", | |
| "This notebook is mostly based on [How Transformer LLMs work](https://learn.deeplearning.ai/courses/how-transformer-llms-work)." | |
| ], | |
| "metadata": { | |
| "id": "jlf4We0xQ2-n" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Let's first load the model and its tokenizer. For that, import the classes `AutoModelForCausalLM` and `AutoTokenizer`. To process a sentence, you can either apply the tokenizer and then the model in two separate steps, or create a `pipeline` object that wraps both steps and apply it to the sentence directly. This notebook explores both approaches, which is why it also imports `pipeline`." | |
| ], | |
| "metadata": { | |
| "id": "GLhdsVeRMY3V" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "# Auto select device (CUDA > MPS > CPU)\n", | |
| "if torch.cuda.is_available():\n", | |
| " device = torch.device(\"cuda\")\n", | |
| "elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n", | |
| " device = torch.device(\"mps\")\n", | |
| "else:\n", | |
| " device = torch.device(\"cpu\")\n", | |
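| "# The rest of the notebook assumes a CUDA GPU (e.g. device_map=\"cuda\" below)\n", | |
| "assert device == torch.device(\"cuda\"), \"No CUDA GPU found: in Colab, use Runtime > Change runtime type > GPU\"" | |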
| ], | |
| "metadata": { | |
| "id": "WQY2q-H6Bwgc" | |
| }, | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import torch\n", | |
| "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", | |
| "\n", | |
| "MODEL_ID = \"Qwen/Qwen2.5-0.5B-Instruct\"\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(\n", | |
| " MODEL_ID,\n", | |
| " cache_dir=\"/content/drive/My Drive/\",\n", | |
| ")\n", | |
| "model = AutoModelForCausalLM.from_pretrained(\n", | |
| "    MODEL_ID,\n", | |
| "    cache_dir=\"/content/drive/My Drive/\",  # cache the weights on the mounted Google Drive\n", | |
| "    device_map=\"cuda\",  # load the weights directly onto the GPU\n", | |
| "    torch_dtype=\"auto\",  # keep the dtype stored in the checkpoint\n", | |
| "    trust_remote_code=True,\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "T4Pj4vfggsxE", | |
| "outputId": "e64215bc-28d5-45bd-bbb1-a5d62e6133d1" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", | |
| "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", | |
| "You will not be requested again.\n", | |
| "Please restart the session if you want to be prompted again.\n", | |
| " warnings.warn(\n", | |
| "`torch_dtype` is deprecated! Use `dtype` instead!\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Now you can wrap the model and the tokenizer in a [pipeline](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline) object for the \"text-generation\" task." | |
| ], | |
| "metadata": { | |
| "id": "qQfsQz9BMjKs" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Create a pipeline\n", | |
| "generator = pipeline(\n", | |
| " \"text-generation\",\n", | |
| " model=model,\n", | |
| " tokenizer=tokenizer,\n", | |
| "    return_full_text=False,  # return only the generated text, without the prompt\n", | |
| "    max_new_tokens=50,  # upper bound on the number of generated tokens\n", | |
| "    do_sample=False,  # greedy decoding: always pick the most likely next token\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "kbQi3M_VgLbH", | |
| "outputId": "5c6e68b3-fce9-4e42-d257-b16aac1d6a5a" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Device set to use cuda\n", | |
| "The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" | |
| ] | |
| } | |
| ] | |
| }, | |
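With `do_sample=False`, the pipeline performs greedy decoding: at each step it appends the single most likely next token, until it has produced `max_new_tokens` tokens or the model emits an end-of-sequence token. The loop below is a minimal sketch of that idea, not the pipeline's actual code; `toy_next_token_scores`, the five-entry vocabulary and `EOS_ID` are hypothetical stand-ins for the real model and tokenizer.

```python
from typing import Callable, List

# Hypothetical toy vocabulary; id 0 plays the role of the end-of-sequence token.
VOCAB = ["<eos>", " Paris", ".", " It", " is"]
EOS_ID = 0

def toy_next_token_scores(ids: List[int]) -> List[float]:
    """Hypothetical stand-in for a model: one score (logit) per vocabulary entry."""
    last = ids[-1]
    if last == 4:                        # after " is", prefer " Paris"
        return [0.0, 3.0, 0.5, 0.1, 0.2]
    if last == 1:                        # after " Paris", prefer "."
        return [0.2, 0.0, 2.5, 0.3, 0.1]
    return [2.0, 0.1, 0.0, 0.5, 0.3]     # otherwise, prefer end-of-sequence

def greedy_generate(prompt_ids: List[int],
                    scorer: Callable[[List[int]], List[float]],
                    max_new_tokens: int = 50) -> List[int]:
    ids = list(prompt_ids)
    new_ids: List[int] = []
    for _ in range(max_new_tokens):
        scores = scorer(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if next_id == EOS_ID:            # stop early at end-of-sequence
            break
        ids.append(next_id)
        new_ids.append(next_id)
    return new_ids                       # like return_full_text=False: no prompt

generated = greedy_generate([4], toy_next_token_scores)
print("".join(VOCAB[i] for i in generated))  # prints " Paris."
```

This is why `max_new_tokens` is an upper bound rather than an exact length: the short chat reply further down stops well before 50 tokens.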
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "We use the pipeline object (named `generator`) to generate a continuation of up to 50 new tokens for the given prompt." | |
| ], | |
| "metadata": { | |
| "id": "JU39CRqMNs_s" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "prompt = \"The capital of France is\"\n", | |
| "output = generator(prompt)\n", | |
| "print(output[0]['generated_text'])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "-ei8TOywgNvS", | |
| "outputId": "5761d13e-fd41-4d30-c5e6-987ee4eff157" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| " Paris. It was founded in 789 AD by Charlemagne, the last king of the Carolingian dynasty. The city has a long and rich history dating back to ancient times. In fact, it's one of the oldest cities\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "You can also pass a list of chat messages; the pipeline then formats them with the model's chat template before generating." | |
| ], | |
| "metadata": { | |
| "id": "PSs3Y5XrgcBp" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "messages = [{\"role\": \"user\", \"content\": \"Hi!\"}]\n", | |
| "output = generator(messages)\n", | |
| "print(output[0]['generated_text'])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "_HRdylyuhmoV", | |
| "outputId": "58a3c3c3-02fc-4e7f-c2c7-ac9af6019101" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Hello! How can I assist you today?\n" | |
| ] | |
| } | |
| ] | |
| }, | |
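Before generating, the chat messages are turned into a single prompt string. Qwen models use a ChatML-style format with `<|im_start|>` and `<|im_end|>` markers. The function below is a simplified approximation of that formatting, assuming plain ChatML; it is not Qwen's exact template, which ships with the tokenizer and may, for example, insert a default system message.

```python
from typing import Dict, List

def to_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Simplified ChatML-style formatting (an approximation, not Qwen's exact template)."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn so that the model's next tokens are the reply
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)

print(to_chatml([{"role": "user", "content": "Hi!"}]))
```

In the notebook itself, `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` returns the exact string the model sees.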
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "We can print the model to take a look at its architecture." | |
| ], | |
| "metadata": { | |
| "id": "nZic8ldAMtzi" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "xCRsErxyiBgx", | |
| "outputId": "8d72ec76-612e-4b7b-c1fd-e53eff0e3caa" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "Qwen2ForCausalLM(\n", | |
| " (model): Qwen2Model(\n", | |
| " (embed_tokens): Embedding(151936, 896)\n", | |
| " (layers): ModuleList(\n", | |
| " (0-23): 24 x Qwen2DecoderLayer(\n", | |
| " (self_attn): Qwen2Attention(\n", | |
| " (q_proj): Linear(in_features=896, out_features=896, bias=True)\n", | |
| " (k_proj): Linear(in_features=896, out_features=128, bias=True)\n", | |
| " (v_proj): Linear(in_features=896, out_features=128, bias=True)\n", | |
| " (o_proj): Linear(in_features=896, out_features=896, bias=False)\n", | |
| " )\n", | |
| " (mlp): Qwen2MLP(\n", | |
| " (gate_proj): Linear(in_features=896, out_features=4864, bias=False)\n", | |
| " (up_proj): Linear(in_features=896, out_features=4864, bias=False)\n", | |
| " (down_proj): Linear(in_features=4864, out_features=896, bias=False)\n", | |
| " (act_fn): SiLUActivation()\n", | |
| " )\n", | |
| " (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)\n", | |
| " (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)\n", | |
| " )\n", | |
| " )\n", | |
| " (norm): Qwen2RMSNorm((896,), eps=1e-06)\n", | |
| " (rotary_emb): Qwen2RotaryEmbedding()\n", | |
| " )\n", | |
| " (lm_head): Linear(in_features=896, out_features=151936, bias=False)\n", | |
| ")" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
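The printed shapes are enough to count the parameters. Note that `k_proj` and `v_proj` output only 128 features against 896 for `q_proj`, because several query heads share each key/value head (grouped-query attention). The sketch below multiplies the shapes out. It assumes that `lm_head` shares its weight matrix with `embed_tokens` (tied embeddings) and treats the rotary embedding as having no trainable weights; the untied total is printed for comparison.

```python
def linear(n_in: int, n_out: int, bias: bool) -> int:
    """Number of parameters of nn.Linear(n_in, n_out, bias=bias)."""
    return n_in * n_out + (n_out if bias else 0)

# Sizes read off the printout above
VOCAB, HIDDEN, KV, FFN, N_LAYERS = 151936, 896, 128, 4864, 24

attention = (linear(HIDDEN, HIDDEN, True)       # q_proj
             + linear(HIDDEN, KV, True)         # k_proj
             + linear(HIDDEN, KV, True)         # v_proj
             + linear(HIDDEN, HIDDEN, False))   # o_proj
mlp = (linear(HIDDEN, FFN, False)               # gate_proj
       + linear(HIDDEN, FFN, False)             # up_proj
       + linear(FFN, HIDDEN, False))            # down_proj
norms = 2 * HIDDEN                              # input_layernorm + post_attention_layernorm
per_layer = attention + mlp + norms

embedding = VOCAB * HIDDEN
# Embedding + decoder layers + final norm; lm_head assumed tied to embed_tokens
total_tied = embedding + N_LAYERS * per_layer + HIDDEN
# If lm_head had its own weight matrix instead
total_untied = total_tied + linear(HIDDEN, VOCAB, False)

print(f"per layer:              {per_layer:,}")
print(f"total (tied lm_head):   {total_tied:,}")
print(f"total (untied lm_head): {total_untied:,}")
```

Under the tying assumption the count comes to about 494M, in line with the "0.5B" in the model name. In the notebook, `model.lm_head.weight is model.model.embed_tokens.weight` and `sum(p.numel() for p in model.parameters())` should confirm or refute the assumption.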
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model.model.embed_tokens" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "nODdrv-Wis_I", | |
| "outputId": "630a5c1d-4153-43de-b18f-19b2bd5ed94f" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "Embedding(151936, 896)" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 7 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
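| "Are tokens that look alike (different spacing and capitalization of \"hello\") close to each other in the embedding space, compared with an unrelated word?" | |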
| ], | |
| "metadata": { | |
| "id": "WZQUXoi4b2wW" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "words = [\" forest\", \"Hello\", \" hello\", \" Hello\", \"hello\"]" | |
| ], | |
| "metadata": { | |
| "id": "ki2kKHBNcWGc" | |
| }, | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Assumes each word is a single token; [0] takes its id\n", | |
| "token_ids = [tokenizer.encode(word)[0] for word in words]\n", | |
| "token_ids" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "viQNtLoacsnh", | |
| "outputId": "a5802b87-650c-467e-b930-443c4ff65070" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "[13638, 9707, 23811, 21927, 14990]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Look up the embedding vector of each token id (output shape: [1, 5, 896])\n", | |
| "model_output = model.model.embed_tokens(torch.tensor([token_ids], device=device))" | |
| ], | |
| "metadata": { | |
| "id": "VMyiTfTHjNy3" | |
| }, | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
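An embedding layer is a lookup table: row `i` of its weight matrix (151936 x 896 here) is the vector for token id `i`. Equivalently, it multiplies a one-hot vector by that matrix. The sketch below checks this equivalence on a made-up 4 x 3 table; the table values are purely illustrative.

```python
from typing import List

# Made-up embedding table: 4 token ids, 3-dimensional vectors
W: List[List[float]] = [
    [0.1, 0.2, 0.3],
    [0.0, -1.0, 0.5],
    [0.7, 0.7, 0.0],
    [-0.2, 0.4, 0.9],
]

def embed_by_lookup(token_ids: List[int]) -> List[List[float]]:
    # What an embedding layer does: select one row per token id
    return [W[i] for i in token_ids]

def embed_by_one_hot(token_ids: List[int]) -> List[List[float]]:
    # Same result computed as one_hot(token_id) @ W
    out = []
    for i in token_ids:
        one_hot = [1.0 if k == i else 0.0 for k in range(len(W))]
        out.append([sum(one_hot[k] * W[k][d] for k in range(len(W)))
                    for d in range(len(W[0]))])
    return out

ids = [2, 0, 2]
print(embed_by_lookup(ids))
print(embed_by_one_hot(ids))
```

In the notebook, indexing the weight directly, `model.model.embed_tokens.weight[token_ids]`, should give the same vectors as calling the layer.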
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def my_distance(a, b):\n", | |
| "    # Squared Euclidean distance between two embedding vectors\n", | |
| "    return ((a-b)**2).sum(axis=0).item()" | |
| ], | |
| "metadata": { | |
| "id": "WhPL-41zkvAA" | |
| }, | |
| "execution_count": 11, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "for i in range(len(model_output[0])):\n", | |
| "    for j in range(i+1, len(model_output[0])):\n", | |
| "        # !r prints the quotes, so leading spaces and capitalization stay visible\n", | |
| "        print(f\"Distance between {words[i]!r} and {words[j]!r}: {my_distance(model_output[0][i], model_output[0][j])}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "496PWKYXk90m", | |
| "outputId": "20888fa9-12c2-44d0-9452-12efb63061a9" | |
| }, | |
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Distance between ' forest' and 'Hello': 0.3125\n", | |
| "Distance between ' forest' and ' hello': 0.30859375\n", | |
| "Distance between ' forest' and ' Hello': 0.302734375\n", | |
| "Distance between ' forest' and 'hello': 0.322265625\n", | |
| "Distance between 'Hello' and ' hello': 0.123046875\n", | |
| "Distance between 'Hello' and ' Hello': 0.0712890625\n", | |
| "Distance between 'Hello' and 'hello': 0.11474609375\n", | |
| "Distance between ' hello' and ' Hello': 0.10302734375\n", | |
| "Distance between ' hello' and 'hello': 0.0849609375\n", | |
| "Distance between ' Hello' and 'hello': 0.126953125\n" | |
| ] | |
| } | |
| ] | |
| }, | |
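`my_distance` returns the squared Euclidean distance, which depends on the lengths of the vectors as well as on their directions. Cosine similarity is a common alternative that compares directions only. A minimal pure-Python sketch on made-up 3-dimensional vectors (the real embeddings have 896 dimensions):

```python
import math
from typing import Sequence

def squared_euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    # Same quantity as my_distance above: sum of squared coordinate differences
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    # 1.0 = same direction, 0.0 = orthogonal, -1.0 = opposite direction
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# Made-up vectors: v2 points the same way as v1 but is twice as long
v1, v2, v3 = [1.0, 2.0, 2.0], [2.0, 4.0, 4.0], [2.0, -1.0, 0.0]
print(squared_euclidean(v1, v2), cosine_similarity(v1, v2))  # 9.0 1.0
print(squared_euclidean(v1, v3), cosine_similarity(v1, v3))  # 14.0 0.0
```

On the real embeddings, `torch.nn.functional.cosine_similarity(model_output[0][i], model_output[0][j], dim=0)` computes the same quantity.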
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Next-token prediction" | |
| ], | |
| "metadata": { | |
| "id": "drKfNzbamWqp" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Let's give the model a prompt and check the first token it will generate." | |
| ], | |
| "metadata": { | |
| "id": "QUgodUQMJ4bf" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "prompt = \"The capital of France is\"" | |
| ], | |
| "metadata": { | |
| "id": "lRbjkjIgjBmI" | |
| }, | |
| "execution_count": 13, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "We first tokenize the prompt and get the ids of the tokens." | |
| ], | |
| "metadata": { | |
| "id": "w4cnE63HKFby" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n", | |
| "input_ids" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Bl_G0vY1jNCM", | |
| "outputId": "d2fb9da9-7d0f-41ea-860c-043618f09b25" | |
| }, | |
| "execution_count": 14, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 785, 6722, 315, 9625, 374]], device='cuda:0')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 14 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
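| "We pass the token ids through the body of the model, `model.model` (the embedding layer, the 24 decoder layers and the final norm), which stops just before the LM head." | |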
| ], | |
| "metadata": { | |
| "id": "_2niwu6WKgSe" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Get the output of the model before the lm_head\n", | |
| "model_output = model.model(input_ids)" | |
| ], | |
| "metadata": { | |
| "id": "j8s7in1FjcnX" | |
| }, | |
| "execution_count": 15, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "For each input token, the model body outputs a hidden-state vector of size 896 (the hidden size, equal to the embedding size). Let's check the shape of this output." | |
| ], | |
| "metadata": { | |
| "id": "zwCeboV-K8n4" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Shape of the hidden states (the output of the model before the lm_head)\n", | |
| "model_output[0].shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "U1XwiPxKluSK", | |
| "outputId": "04c1815a-070d-4999-82f5-0f4820aaa9ad" | |
| }, | |
| "execution_count": 16, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1, 5, 896])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 16 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "The first number is the batch size, which is 1 here since we have a single prompt. The second number, 5, is the number of tokens. The last number, 896, is the hidden size (the length of the vector associated with each token).\n", | |
| "\n", | |
| "Let's now get the output of the LM head." | |
| ], | |
| "metadata": { | |
| "id": "UqvhVK80LbVJ" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Get the output of the lm_head\n", | |
| "lm_head_output = model.lm_head(model_output[0])" | |
| ], | |
| "metadata": { | |
| "id": "nCfW2ooelz7H" | |
| }, | |
| "execution_count": 17, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "lm_head_output.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "RWpH1KfeOogz", | |
| "outputId": "25d4e8e9-96c5-4d90-e6c9-43b5f07b3e00" | |
| }, | |
| "execution_count": 18, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1, 5, 151936])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 18 | |
| } | |
| ] | |
| }, | |
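The LM head is a single linear layer without bias (896 inputs, 151936 outputs): each position's hidden vector is multiplied by the weight matrix, giving one score per vocabulary entry. The toy sketch below uses made-up sizes (2 positions, hidden size 3, vocabulary of 4) to reproduce the shape bookkeeping `[seq_len, hidden] -> [seq_len, vocab]`.

```python
from typing import List

Matrix = List[List[float]]

def lm_head(hidden: Matrix, weight: Matrix) -> Matrix:
    """hidden: [seq_len, hidden_size]; weight: [vocab_size, hidden_size] (nn.Linear layout).
    Returns logits of shape [seq_len, vocab_size], i.e. hidden @ weight^T, with no bias."""
    return [[sum(h * w for h, w in zip(row, w_row)) for w_row in weight] for row in hidden]

# Made-up hidden states and weights
hidden_states = [[1.0, 0.0, 2.0],     # position 0
                 [0.5, 1.0, -1.0]]    # position 1 (the last token)
W = [[1.0, 0.0, 0.0],    # vocab id 0
     [0.0, 1.0, 0.0],    # vocab id 1
     [0.0, 0.0, 1.0],    # vocab id 2
     [1.0, 1.0, 1.0]]    # vocab id 3

logits = lm_head(hidden_states, W)
print(len(logits), len(logits[0]))  # 2 4
print(logits[-1])                   # scores for the token after the last position
```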
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "[https://x.com/gabriberton/status/2007268853072720266](https://x.com/gabriberton/status/2007268853072720266)\n", | |
| "\n", | |
| "[https://x.com/gabriberton/status/2007327212438204893?s=20](https://x.com/gabriberton/status/2007327212438204893?s=20)" | |
| ], | |
| "metadata": { | |
| "id": "UbpSDRfyP4ec" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "For each token in the input prompt, the LM head outputs a vector of size 151936 (the padded vocabulary size). So there are 5 vectors, each of size 151936. Their entries are logits (unnormalized scores); applying a softmax turns each vector into a probability distribution over the vocabulary for the token that follows the corresponding position, given all the prompt tokens up to that position.\n", | |
| "\n", | |
| "Since we want the token that comes after the last prompt token (\"is\"), we only need the last vector, `lm_head_output[0, -1]`, of size 151936. Greedy decoding picks the id of the token with the highest value in that vector, which is what `argmax(-1)` returns (-1 means along the last axis)." | |
| ], | |
| "metadata": { | |
| "id": "1dd8N9CHL7QX" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "token_id = lm_head_output[0,-1].argmax(-1)\n", | |
| "token_id" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "t5b3SOj3l86w", | |
| "outputId": "3f84921d-b340-4432-9496-2915a3b5fbe5" | |
| }, | |
| "execution_count": 19, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "tensor(12095, device='cuda:0')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 19 | |
| } | |
| ] | |
| }, | |
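Because softmax is strictly increasing in each logit, the token with the highest logit is also the token with the highest probability, so taking `argmax` on the raw logits (as above) gives the same greedy choice as taking it on the probabilities. A small numerically stable softmax sketch on made-up logits:

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    # Subtracting the max before exponentiating avoids overflow and leaves the result unchanged
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values: List[float]) -> int:
    return max(range(len(values)), key=values.__getitem__)

toy_logits = [2.0, 1.0, 0.1, -3.0]   # made-up scores for a 4-token vocabulary
toy_probs = softmax(toy_logits)
print([round(p, 3) for p in toy_probs])
print(argmax(toy_logits), argmax(toy_probs))  # 0 0
```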
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Finally, let's decode the returned token id." | |
| ], | |
| "metadata": { | |
| "id": "DBRzC8dZLxjb" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "tokenizer.decode(token_id)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 36 | |
| }, | |
| "id": "k5wSp8wVmE1O", | |
| "outputId": "5e092825-97f9-45e9-ac3d-8a629883a914" | |
| }, | |
| "execution_count": 20, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "' Paris'" | |
| ], | |
| "application/vnd.google.colaboratory.intrinsic+json": { | |
| "type": "string" | |
| } | |
| }, | |
| "metadata": {}, | |
| "execution_count": 20 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import pandas as pd\n", | |
| "\n", | |
| "prompt = \"The capital of France is\"\n", | |
| "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n", | |
| "outputs = model(input_ids)\n", | |
| "logits = outputs.logits[:, -1, :]  # logits for the token after the last prompt token\n", | |
| "probs = torch.softmax(logits, dim=-1)  # logits -> probability distribution\n", | |
| "\n", | |
| "# The vocabulary dimension is padded: keep only the len(tokenizer) real tokens\n", | |
| "token_ids = list(range(len(tokenizer)))\n", | |
| "tokens = [tokenizer.decode(token_id) for token_id in token_ids]\n", | |
| "df = pd.DataFrame({\"token_id\": token_ids, \"token\": tokens, \"probability\": probs[0][:len(tokenizer)].tolist()})\n", | |
| "df.sort_values(by=\"probability\", ascending=False).head(10)" | |
| ], | |
| "metadata": { | |
| "id": "JrFhBmSJJrEv", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 362 | |
| }, | |
| "outputId": "f8177f0c-7c73-43f1-eed0-aeb103346294" | |
| }, | |
| "execution_count": 21, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| " token_id token probability\n", | |
| "12095 12095 Paris 0.291016\n", | |
| "32671 32671 ______ 0.121582\n", | |
| "510 510 :\\n 0.069336\n", | |
| "1447 1447 :\\n\\n 0.057373\n", | |
| "1304 1304 __ 0.053955\n", | |
| "7407 7407 located 0.047607\n", | |
| "30743 30743 ____ 0.041992\n", | |
| "279 279 the 0.034912\n", | |
| "320 320 ( 0.025513\n", | |
| "508 508 [ 0.022461" | |
| ], | |
| "application/vnd.google.colaboratory.intrinsic+json": { | |
| "type": "dataframe", | |
| "summary": "{\n \"name\": \"df\",\n \"rows\": 10,\n \"fields\": [\n {\n \"column\": \"token_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 12720,\n \"min\": 279,\n \"max\": 32671,\n \"num_unique_values\": 10,\n \"samples\": [\n 320,\n 32671,\n 7407\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"token\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10,\n \"samples\": [\n \" (\",\n \" ______\",\n \" located\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"probability\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.08044996587604922,\n \"min\": 0.0224609375,\n \"max\": 0.291015625,\n \"num_unique_values\": 10,\n \"samples\": [\n 0.0255126953125,\n 0.12158203125,\n 0.047607421875\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" | |
| } | |
| }, | |
| "metadata": {}, | |
| "execution_count": 21 | |
| } | |
| ] | |
| }, | |
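The probabilities in the table above can also be used to sample the next token instead of always taking the most likely one. A common knob is temperature: dividing the logits by T before the softmax sharpens the distribution (T < 1) or flattens it (T > 1). The sketch below uses made-up logits for four made-up candidates and Python's `random.Random` for reproducible draws; it illustrates the idea and is not the pipeline's internal implementation.

```python
import math
import random
from collections import Counter
from typing import List

def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    scaled = [x / temperature for x in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

candidates = [" Paris", " located", " the", " ______"]   # made-up candidates
toy_logits = [3.0, 1.5, 1.0, 2.0]                        # made-up scores

for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(toy_logits, t)])

# Draw 1000 samples at T = 1 with a fixed seed (reproducible, like manual_seed)
rng = random.Random(0)
weights = softmax_with_temperature(toy_logits, 1.0)
counts = Counter(rng.choices(candidates, weights=weights, k=1000))
print(counts.most_common())
```

As T approaches 0 the distribution concentrates on the top token, so sampling approaches greedy decoding.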
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "df[df['token'] == \" Brussels\"]" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 80 | |
| }, | |
| "id": "G9LBF5BTqO8t", | |
| "outputId": "2653d543-7430-40ed-f55e-dd8b28c29329" | |
| }, | |
| "execution_count": 22, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| " token_id token probability\n", | |
| "37169 37169 Brussels 0.000362" | |
| ], | |
| "application/vnd.google.colaboratory.intrinsic+json": { | |
| "type": "dataframe", | |
| "summary": "{\n \"name\": \"df[df['token'] == \\\" Brussels\\\"]\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"token_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 37169,\n \"max\": 37169,\n \"num_unique_values\": 1,\n \"samples\": [\n 37169\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"token\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \" Brussels\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"probability\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 0.000362396240234375,\n \"max\": 0.000362396240234375,\n \"num_unique_values\": 1,\n \"samples\": [\n 0.000362396240234375\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" | |
| } | |
| }, | |
| "metadata": {}, | |
| "execution_count": 22 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Sample one token per seed to see how multinomial sampling varies\n", | |
| "sampled_tokens = []\n", | |
| "for seed in range(100):\n", | |
| "    g_cuda = torch.Generator(device='cuda').manual_seed(seed)  # reproducible RNG for this seed\n", | |
| "    sampled_token = torch.multinomial(probs, num_samples=1, generator=g_cuda)\n", | |
| "    print(f\"{seed}: {tokenizer.decode(sampled_token[0])}\")\n", | |
| "    sampled_tokens.append(sampled_token[0])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "onyYX3SUoy6b", | |
| "outputId": "f7c00acb-0516-4759-c99d-90147524ca46" | |
| }, | |
| "execution_count": 23, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "0: Paris\n", | |
| "1: Paris\n", | |
| "2: the\n", | |
| "3: \n", | |
| "\n", | |
| "4: commonly\n", | |
| "5: Î\n", | |
| "6: Paris\n", | |
| "7: :\n", | |
| "\n", | |
| "8: Paris\n", | |
| "9: Paris\n", | |
| "10: Paris\n", | |
| "11: the\n", | |
| "12: :\n", | |
| "\n", | |
| "13: Paris\n", | |
| "14: ______\n", | |
| "15: ______\n", | |
| "16: \n", | |
| "\n", | |
| "\n", | |
| "17: what\n", | |
| "18: currently\n", | |
| "19: [\n", | |
| "20: Rome\n", | |
| "21: ______\n", | |
| "22: ______\n", | |
| "23: Paris\n", | |
| "24: ?\n", | |
| "\n", | |
| "\n", | |
| "25: :\n", | |
| "\n", | |
| "\n", | |
| "26: Paris\n", | |
| "27: situated\n", | |
| "28: .\n", | |
| "\n", | |
| "\n", | |
| "29: the\n", | |
| "30: Paris\n", | |
| "31: ______\n", | |
| "32: called\n", | |
| "33: the\n", | |
| "34: ______\n", | |
| "35: ______\n", | |
| "36: \n", | |
| "\n", | |
| "37: Paris\n", | |
| "38: generally\n", | |
| "39: :\n", | |
| "\n", | |
| "40: Paris\n", | |
| "41: ....\n", | |
| "\n", | |
| "\n", | |
| "42: ____\n", | |
| "43: Paris\n", | |
| "44: (\n", | |
| "45: ___\n", | |
| "46: located\n", | |
| "47: located\n", | |
| "48: Paris\n", | |
| "49: \n", | |
| "\n", | |
| "50: can\n", | |
| "51: Paris\n", | |
| "52: Paris\n", | |
| "53: currently\n", | |
| "54: [\n", | |
| "55: Paris\n", | |
| "56: __\n", | |
| "57: \n", | |
| "\n", | |
| "58: Paris\n", | |
| "59: :\n", | |
| "\n", | |
| "60: Paris\n", | |
| "61: Paris\n", | |
| "62: ____\n", | |
| "63: :\n", | |
| "\n", | |
| "\n", | |
| "64: (\n", | |
| "65: located\n", | |
| "66: known\n", | |
| "67: ______\n", | |
| "68: _____\n", | |
| "69: located\n", | |
| "70: Paris\n", | |
| "71: :\n", | |
| "\n", | |
| "72: situated\n", | |
| "73: Paris\n", | |
| "74: :\n", | |
| "\n", | |
| "\n", | |
| "75: _____\n", | |
| "76: \n", | |
| "\n", | |
| "77: typically\n", | |
| "78: not\n", | |
| "79: located\n", | |
| "80: ______\n", | |
| "81: decided\n", | |
| "82: :\n", | |
| "\n", | |
| "83: Paris\n", | |
| "84: Paris\n", | |
| "85: :\n", | |
| "\n", | |
| "86: Paris\n", | |
| "87: called\n", | |
| "88: light\n", | |
| "89: :\n", | |
| "\n", | |
| "90: a\n", | |
| "91: \n", | |
| "\n", | |
| "92: [\n", | |
| "93: Paris\n", | |
| "94: __\n", | |
| "95: _____\n", | |
| "96: ______\n", | |
| "97: Paris\n", | |
| "98: ______\n", | |
| "99: The\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Reproduce the draw for seed 20, which sampled Rome instead of Paris\n", | |
| "seed = 20\n", | |
| "g_cuda = torch.Generator(device='cuda').manual_seed(seed)\n", | |
| "sampled_token = torch.multinomial(probs, num_samples=1, generator=g_cuda)\n", | |
| "print(f\"{seed}: {tokenizer.decode(sampled_token[0])}\")" | |
| ], | |
| "metadata": { | |
| "id": "y2kEskUgpAv1", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "5ddcb56d-7e2f-41d1-87d6-bfb8cc5f2ab3" | |
| }, | |
| "execution_count": 24, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "20: Rome\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Sequential next-token prediction (autoregressive text generation)" | |
| ], | |
| "metadata": { | |
| "id": "kWt8a7lflHfV" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "message = [{\"role\": \"user\", \"content\": \"Why is the sky blue?\"}]" | |
| ], | |
| "metadata": { | |
| "id": "C6_3DZIcAR3L" | |
| }, | |
| "execution_count": 25, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n", | |
| "print(prompt)" | |
| ], | |
| "metadata": { | |
| "id": "D4FVihnkAZy_", | |
| "outputId": "c1758719-062e-4db1-9a7a-487804ebefd9", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| } | |
| }, | |
| "execution_count": 26, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "<|im_start|>system\n", | |
| "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n", | |
| "<|im_start|>user\n", | |
| "Why is the sky blue?<|im_end|>\n", | |
| "<|im_start|>assistant\n", | |
| "\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids" | |
| ], | |
| "metadata": { | |
| "id": "MLxb3RWrd46C" | |
| }, | |
| "execution_count": 27, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
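| "Multinomial sampling: at each step, draw the next token at random in proportion to its predicted probability (`torch.multinomial`), append it to the input, and repeat." | |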
| ], | |
| "metadata": { | |
| "id": "NbQuVaU_k5qv" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
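| "current_response = input_ids.to(device)\n", | |
| "with torch.no_grad():  # inference only, no gradient tracking needed\n", | |
| "    for _ in range(50):  # generate up to 50 new tokens\n", | |
| "        outputs = model(current_response)\n", | |
| "        logits = outputs.logits[:, -1, :]  # logits at the last position predict the next token\n", | |
| "        probs = torch.softmax(logits, dim=-1)\n", | |
| "        sampled_token = torch.multinomial(probs, num_samples=1)  # draw one token at random\n", | |
| "        if sampled_token.item() == tokenizer.eos_token_id:  # stop at the end-of-turn token\n", | |
| "            break\n", | |
| "        print(tokenizer.decode(sampled_token[0]), end=\"\")\n", | |
| "        # Append the new token so the next step conditions on it\n", | |
| "        current_response = torch.cat([current_response, sampled_token], dim=-1)" | |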
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "PBAvDG-VeHH0", | |
| "outputId": "cb7ac56b-b4e9-4593-fee7-3a421e0cd5d7" | |
| }, | |
| "execution_count": 28, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "The sky appears blue due to various factors that scientists attribute to the reasons for its vivid colors. One explanation is the presence of dissolved salts in the atmosphere that emit short-wavelength blue light, which is absorbed by water vapor and doesn’t reach the observer" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
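| "Greedy search: at each step, append the single most probable token (`torch.argmax`). Because no randomness is involved, repeated runs generally produce the same text." | |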
| ], | |
| "metadata": { | |
| "id": "wweWEI3ClDBp" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
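| "current_response = input_ids.to(device)\n", | |
| "with torch.no_grad():  # inference only, no gradient tracking needed\n", | |
| "    for _ in range(50):  # generate up to 50 new tokens\n", | |
| "        outputs = model(current_response)\n", | |
| "        logits = outputs.logits[:, -1, :]  # logits at the last position predict the next token\n", | |
| "        probs = torch.softmax(logits, dim=-1)\n", | |
| "        # Pick the most probable token (argmax of the raw logits would give the same result)\n", | |
| "        sampled_token = torch.argmax(probs, dim=-1, keepdim=True)\n", | |
| "        if sampled_token.item() == tokenizer.eos_token_id:  # stop at the end-of-turn token\n", | |
| "            break\n", | |
| "        print(tokenizer.decode(sampled_token[0]), end=\"\")\n", | |
| "        current_response = torch.cat([current_response, sampled_token], dim=-1)" | |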
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Z-ts0pmzi5B6", | |
| "outputId": "02c24d04-fb3a-4fbf-a845-94b600c8839d" | |
| }, | |
| "execution_count": 29, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "The sky is blue because of the scattering of sunlight by tiny particles in the Earth's atmosphere. When sunlight enters the Earth's atmosphere, it is scattered in all directions by these tiny particles, such as water droplets, ice crystals, and tiny dust" | |
| ] | |
| } | |
| ] | |
| } | |
| ] | |
| } |
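The two decoding loops in this notebook call the Qwen model on a GPU. As a minimal sketch of the same logic that runs anywhere, the code below swaps the model for a hypothetical toy bigram table (`TOY_BIGRAMS`, invented for illustration) and uses only Python's standard library: `greedy_next` plays the role of `torch.argmax` and `sample_next` the role of `torch.multinomial`. Unlike a real LLM, which conditions on the whole `current_response` sequence, the toy table only looks at the last token, so treat it as an illustration of the decoding rules rather than of the model.

```python
import random
from collections import Counter

# Hypothetical toy "model": maps the last token to a next-token distribution.
# A real LLM conditions on the whole sequence; this table only sees the last token.
TOY_BIGRAMS = {
    "<start>": {"The": 0.6, "A": 0.4},
    "The": {"sky": 0.7, "sea": 0.3},
    "A": {"sky": 0.5, "sea": 0.5},
    "sky": {"is": 0.9, "looks": 0.1},
    "sea": {"is": 0.8, "looks": 0.2},
    "is": {"blue": 0.8, "grey": 0.2},
    "looks": {"blue": 0.6, "grey": 0.4},
    "blue": {"<end>": 1.0},
    "grey": {"<end>": 1.0},
}


def greedy_next(probs):
    """Greedy search: pick the most probable token (like torch.argmax)."""
    return max(probs, key=probs.get)


def sample_next(probs, rng):
    """Multinomial sampling: draw one token in proportion to its probability
    (like torch.multinomial(probs, num_samples=1))."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


def generate(pick, max_new_tokens=10):
    """Autoregressive loop: pick a next token, append it, and repeat."""
    sequence = ["<start>"]
    for _ in range(max_new_tokens):
        next_token = pick(TOY_BIGRAMS[sequence[-1]])
        if next_token == "<end>":  # stop at the end marker, like an EOS token
            break
        sequence.append(next_token)
    return sequence[1:]  # drop the <start> marker


greedy_text = generate(greedy_next)
print("greedy:", " ".join(greedy_text))  # greedy: The sky is blue

rng = random.Random(20)  # seeded, in the spirit of torch.Generator().manual_seed(20)
sampled_text = generate(lambda probs: sample_next(probs, rng))
print("sampled:", " ".join(sampled_text))

# One draw per seed, as in the 100-seed experiment: frequencies approach the probabilities.
counts = Counter(
    sample_next(TOY_BIGRAMS["is"], random.Random(seed)) for seed in range(10_000)
)
freq_blue = counts["blue"] / 10_000
print(f"P(blue | is) = 0.8, empirical = {freq_blue:.3f}")
```

Greedy search follows the highest-probability path every time, while sampling can take a different path for each seed. The last lines mirror the 100-seed experiment earlier in the notebook: a high-probability token such as Paris dominates the draws, but lower-probability tokens such as Rome still appear.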