/*
Cargo.toml

[package]
name = "candle-app"
version = "0.1.0"
edition = "2024"

[dependencies]
anyhow = "1"
candle-core = "0.9.2"
candle-nn = "0.9.2"
candle-transformers = "0.9.2"
hf-hub = "0.4"
tokenizers = "0.21"
*/
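
// Minimal end-to-end example: download a quantized TinyLlama GGUF from the
// Hugging Face Hub, load it with candle's quantized llama implementation, and
// stream a sampled chat completion to stdout, entirely on CPU.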
use anyhow::{Context, Result};
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
use hf_hub::api::sync::Api;
use std::io::Write;
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    println!("Downloading model from HuggingFace (first run may take a moment)...");
    let api = Api::new()?;

    // TinyLlama 1.1B Chat - quantized Q4_K_M GGUF (~637MB)
    let model_repo = api.model("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF".to_string());
    let model_path = model_repo.get("tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?;
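
    // The GGUF repo ships only the quantized weights; tokenizer.json comes from
    // the original (unquantized) TinyLlama repo. hf-hub caches both downloads,
    // so `get` resolves to local files on later runs.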
    let tokenizer_repo = api.model("TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string());
    let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;

    println!("Loading model...");
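    // tokenizers uses its own boxed error type that `?` cannot convert into
    // anyhow::Error directly, hence the explicit map_err.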
    let tokenizer =
        Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::anyhow!("{e}"))?;
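
    // `Content::read` parses the GGUF header (metadata plus tensor index);
    // `from_gguf` then reads the quantized tensors from the same file handle
    // and assembles the llama model on the chosen device (CPU here).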
    let mut file = std::fs::File::open(&model_path)?;
    let gguf = candle_core::quantized::gguf_file::Content::read(&mut file)
        .context("Failed to read GGUF file")?;
    let mut llm = model::ModelWeights::from_gguf(gguf, &mut file, &Device::Cpu)
        .context("Failed to load model weights")?;

    println!("Generating response...\n");

    // TinyLlama chat template (Zephyr-style)
    let prompt = "<|system|>\nYou are a helpful assistant.</s>\n<|user|>\nTell me a story in 200 words</s>\n<|assistant|>\n";
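    // If this template doesn't match the one the chat model was fine-tuned on,
    // generation still runs but output quality degrades noticeably.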

    let encoding = tokenizer
        .encode(prompt, true)
        .map_err(|e| anyhow::anyhow!("{e}"))?;
    let prompt_tokens = encoding.get_ids();
    let prompt_len = prompt_tokens.len();
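
    // LogitsProcessor::new(seed, temperature, top_p): a fixed seed (42) makes
    // runs reproducible, temperature 0.8 slightly sharpens the distribution
    // (values below 1.0 favor likelier tokens), and top-p 0.95 samples from the
    // smallest set of tokens whose cumulative probability reaches 0.95.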
    let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
    let eos_token = 2u32; // </s> for llama-based models

    // Forward pass on the full prompt. `unsqueeze(0)` adds the batch dimension;
    // index_pos 0 tells the model this chunk starts at position 0, so the KV
    // cache gets filled for the whole prompt in a single pass.
    let input = Tensor::new(prompt_tokens, &Device::Cpu)?.unsqueeze(0)?;
    let logits = llm.forward(&input, 0)?;
    let logits = logits.squeeze(0)?;
    // Take last position's logits (handle [seq_len, vocab] vs [vocab] shapes)
    let last_logits = if logits.dims().len() == 2 {
        logits.get(logits.dim(0)? - 1)?
    } else {
        logits
    };
    let mut next_token = logits_processor.sample(&last_logits)?;
    let mut all_tokens = vec![next_token];
    let mut prev_text_len = 0;

    // Auto-regressive generation
    let max_new_tokens = 500;
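    // Each iteration appends one token: print the text produced so far, run the
    // model on just the newest token (the KV cache covers everything before it),
    // and sample the next token from the resulting logits.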
    for i in 0..max_new_tokens {
        if next_token == eos_token {
            break;
        }
        // Decode all tokens and print only the new characters. Using `get`
        // instead of direct slicing avoids a panic if a token boundary ever
        // falls inside a multi-byte UTF-8 character.
        if let Ok(text) = tokenizer.decode(&all_tokens, true) {
            if let Some(new_text) = text.get(prev_text_len..) {
                if !new_text.is_empty() {
                    print!("{new_text}");
                    std::io::stdout().flush()?;
                }
                prev_text_len = text.len();
            }
        }

        // Feed only the newly sampled token; earlier positions are served from
        // the KV cache. The token fed at step i sits at position prompt_len + i
        // (the original `prompt_len + i + 1` was off by one).
        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
        let logits = llm.forward(&input, prompt_len + i)?;
        let logits = logits.squeeze(0)?.squeeze(0)?;
        next_token = logits_processor.sample(&logits)?;
        if next_token != eos_token {
            all_tokens.push(next_token);
        }
    }

    // Print any remaining text
    if let Ok(text) = tokenizer.decode(&all_tokens, true) {
        if let Some(new_text) = text.get(prev_text_len..) {
            if !new_text.is_empty() {
                print!("{new_text}");
            }
        }
    }
    println!("\n");

    Ok(())
}
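
// Note: run with `cargo run --release`; candle inference in a debug build
// tends to be far too slow to be usable.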