@criminact
Created April 20, 2024 21:35
llama3-8b-instruct-demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyPl9+5rnQapXcUOPw1kqvOG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/criminact/ba0472e79303bf282752f597f492cbb5/llama3-8b-instruct-demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip3 install -U accelerate transformers"
],
"metadata": {
"id": "zYqz766pNIVl"
},
"execution_count": null,
"outputs": []
},
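{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Note (not in the original gist):* `meta-llama/Meta-Llama-3-8B-Instruct` is a gated repository, so downloading the weights requires a Hugging Face access token with read scope. The sketch below authenticates once per session with `huggingface_hub.login`, as an alternative to passing `token=` to the pipeline; the placeholder is the same one used later in this notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Optional sketch: log in to the Hugging Face Hub once so the gated Llama 3\n",
"# weights can be downloaded. Replace the placeholder with a real read token\n",
"# from https://huggingface.co/settings/tokens.\n",
"from huggingface_hub import login\n",
"\n",
"login(token=\"<HUGGINGFACE_READ_TOKEN>\")"
]
},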
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zpTfB5fFMwJF"
},
"outputs": [],
"source": [
"import transformers\n",
"import torch\n",
"\n",
"model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"\n",
"pipeline = transformers.pipeline(\n",
" \"text-generation\",\n",
" model=model_id,\n",
" model_kwargs={\"torch_dtype\": torch.float16},\n",
" device=\"auto\",\n",
" token = \"<HUGGINGFACE_READ_TOKEN>\"\n",
")\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
" {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
"]\n",
"\n",
"prompt = pipeline.tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True\n",
")\n",
"\n",
"terminators = [\n",
" pipeline.tokenizer.eos_token_id,\n",
" pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
"]\n",
"\n",
"outputs = pipeline(\n",
" prompt,\n",
" max_new_tokens=256,\n",
" eos_token_id=terminators,\n",
" do_sample=True,\n",
" temperature=0.6,\n",
" top_p=0.9,\n",
")\n",
"print(outputs[0][\"generated_text\"][len(prompt):])"
]
}
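,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Hedged extension (not part of the original gist):* because `pipeline`, `messages`, and `terminators` are still in scope, the chat can be continued for another turn by appending the assistant's reply and a new user message, then re-rendering the template. The follow-up question and the `assistant_reply` name are illustrative."
]
},
{
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": [],
"source": [
"# Continue the conversation: record the model's reply, add a follow-up\n",
"# question, and generate again with the same sampling settings.\n",
"assistant_reply = outputs[0][\"generated_text\"][len(prompt):]\n",
"messages.append({\"role\": \"assistant\", \"content\": assistant_reply})\n",
"messages.append({\"role\": \"user\", \"content\": \"What treasure do ye seek?\"})\n",
"\n",
"prompt = pipeline.tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True\n",
")\n",
"\n",
"outputs = pipeline(\n",
"    prompt,\n",
"    max_new_tokens=256,\n",
"    eos_token_id=terminators,\n",
"    do_sample=True,\n",
"    temperature=0.6,\n",
"    top_p=0.9,\n",
")\n",
"print(outputs[0][\"generated_text\"][len(prompt):])"
]
}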
]
}