Created
January 6, 2026 01:31
-
-
Save notnotrishi/1da3844064ac19a93791084b9bdcc8db to your computer and use it in GitHub Desktop.
Flask app that converts a single image into a 3D scene (Gaussian splat) using Apple's SHARP model and renders it in Mark Kellogg's GaussianSplats3D viewer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Flask application for converting images to 3D using Apple's SHARP model. | |
| Installation: | |
| 1. Clone ml-sharp repository: | |
| git clone https://github.com/apple/ml-sharp | |
| cd ml-sharp | |
| pip install -r requirements.txt | |
| 2. Install Flask and Flask-CORS: | |
| pip install flask flask-cors | |
| 3. Place this script in the ml-sharp directory | |
| 4. Run: python app.py | |
| 5. Navigate to http://localhost:5050 | |
| """ | |
import io
import logging
import uuid
from pathlib import Path

import numpy as np
import torch
from flask import Flask, render_template_string, request, jsonify, send_file, send_from_directory
from flask_cors import CORS
from PIL import Image, ImageOps
app = Flask(__name__)
# NOTE(review): CORS(app) with no arguments lets any origin call this API and
# read its responses; the bundled pages are same-origin and do not need it.
CORS(app)


@app.after_request
def add_headers(response):
    """Attach cross-origin isolation headers to every response.

    COOP ``same-origin`` combined with COEP ``require-corp`` makes pages
    cross-origin isolated, which browsers require before exposing
    SharedArrayBuffer (presumably wanted by the GaussianSplats3D viewer).

    Returns the same response object, modified in place.
    """
    isolation_headers = (
        ('Cross-Origin-Opener-Policy', 'same-origin'),
        ('Cross-Origin-Embedder-Policy', 'require-corp'),
    )
    for header_name, header_value in isolation_headers:
        response.headers[header_name] = header_value
    return response
| from sharp.models import create_predictor | |
| from sharp.models.params import PredictorParams | |
| from sharp.utils.gaussians import save_ply | |
| from sharp.cli.predict import predict_image | |
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Model state shared by the request handlers. Both stay None until
# load_model() assigns them.
predictor = device = None
def select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps.is_available() is the long-standing documented probe;
    # the torch.mps.is_available() alias is not present in older releases.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_model(model_url="https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"):
    """Load the SHARP predictor and store it in the module globals.

    Sets the module-level ``device`` (see ``select_device``) and ``predictor``
    (in eval mode, moved to ``device``) that the /process handler uses.

    Args:
        model_url: URL of the SHARP state-dict checkpoint. Defaults to the
            checkpoint published by Apple. torch.hub downloads it once and
            reuses its local cached copy afterwards.
    """
    global predictor, device
    device = select_device()
    LOGGER.info(f"Using device: {device}")
    LOGGER.info("Loading SHARP model...")
    # Deserialize onto the CPU so loading works no matter which device the
    # checkpoint tensors were saved from; the model is moved afterwards.
    state_dict = torch.hub.load_state_dict_from_url(
        model_url, progress=True, map_location="cpu"
    )
    model = create_predictor(PredictorParams())
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    # Publish the global only once it is fully initialised, so a
    # `predictor is None` check never sees a half-loaded model.
    predictor = model
    LOGGER.info("Model loaded")
# Jinja template for the /viewer page; expects a ``filename`` variable (the
# .ply name under the static folder). The scene URL is emitted with ``|tojson``
# so it becomes a correctly escaped JavaScript string literal. Plain ``{{ }}``
# output is HTML-escaped, which is the wrong escaping inside a <script> block:
# an "&" in the URL, for example, would reach the loader as "&amp;".
VIEWER_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
  <title>3D Viewer</title>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <style>
    body { margin:0; overflow:hidden; font-family: monospace; background:#000; }
    #info { position:absolute; top:15px; left:15px; color:white; background:rgba(0,0,0,0.7); padding:15px; border-radius:8px; z-index:1000; font-size:14px; max-width:300px }
    #loading { position:absolute; top:50%; left:50%; transform:translate(-50%,-50%); color:white; font-size:18px; text-align:center }
    .spinner{ border:4px solid rgba(255,255,255,0.3); border-top:4px solid white; border-radius:50%; width:50px; height:50px; animation:spin 1s linear infinite; margin:0 auto 15px }
    @keyframes spin{0%{transform:rotate(0deg)}100%{transform:rotate(360deg)}}
  </style>
</head>
<body>
  <div id="info">
    <div><strong>GaussianSplats3D Viewer</strong></div>
    <div id="status">Loading...</div>
    <div style="margin-top:10px; font-size:12px; color:#aaa;">Mouse: Drag to orbit, scroll to zoom<br>File: {{ filename }}</div>
  </div>
  <div id="loading"><div class="spinner"></div><div>Loading 3D scene...</div></div>
  <script type="importmap">
  {
    "imports": {
      "three": "https://cdn.jsdelivr.net/npm/three@0.160.0/build/three.module.js",
      "@mkkellogg/gaussian-splats-3d": "https://cdn.jsdelivr.net/npm/@mkkellogg/gaussian-splats-3d@0.4.7/build/gaussian-splats-3d.module.js"
    }
  }
  </script>
  <script type="module">
    import * as GaussianSplats3D from '@mkkellogg/gaussian-splats-3d';
    const statusEl = document.getElementById('status');
    const loadingEl = document.getElementById('loading');
    statusEl.textContent = 'Initializing viewer...';
    const viewer = new GaussianSplats3D.Viewer({
      cameraUp: [0, -1, -0.6],
      initialCameraPosition: [0, -2, -5],
      initialCameraLookAt: [0, 0, 0],
      sphericalHarmonicsDegree: 0,
      enableSplatSort: true,
      freeIntermediateSplatData: true
    });
    viewer.addSplatScene({{ url_for('static', filename=filename)|tojson }}, {
      progressiveLoad: true,
      showLoadingUI: false,
      position: [0,0,0], rotation: [0,0,0,1], scale: [1,1,1]
    })
    .then(()=>{
      statusEl.textContent = 'Loaded successfully!';
      loadingEl.style.display='none';
      viewer.start();
      setTimeout(()=>{document.getElementById('info').style.opacity='0.7'},3000);
    })
    .catch((err)=>{
      statusEl.textContent = 'Error: '+err.message; statusEl.style.color='#ff6b6b'; console.error('Error loading splat:', err);
    });
  </script>
</body>
</html>
"""
# Upload page served at "/". It posts the chosen image to /process and, on
# success, links to the generated .ply and to /viewer. Status messages are
# written as plain text because they can contain server-supplied error strings.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
  <title>Image to 3D</title>
  <style>
    * { margin: 0; padding: 0; box-sizing: border-box; }
    body {
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
      background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
      min-height: 100vh;
      display: flex;
      align-items: center;
      justify-content: center;
      padding: 20px;
    }
    .container {
      background: white;
      border-radius: 20px;
      box-shadow: 0 20px 60px rgba(0,0,0,0.3);
      max-width: 900px;
      width: 100%;
      padding: 40px;
    }
    h1 { font-size: 32px; color: #333; margin-bottom: 10px; }
    .subtitle { color: #666; font-size: 14px; margin-bottom: 30px; }
    .note { background: #fff3cd; border-left: 4px solid #ffc107; padding: 12px; margin-bottom: 20px; font-size: 13px; color: #856404; }
    .drop-zone { border: 3px dashed #667eea; border-radius: 12px; padding: 60px 20px; text-align: center; cursor: pointer; transition: all 0.3s; background: #f8f9ff; margin-bottom: 20px; }
    .drop-zone:hover, .drop-zone.dragover { border-color: #764ba2; background: #f0f0ff; transform: scale(1.02); }
    .drop-zone-text { color: #667eea; font-size: 18px; font-weight: 600; margin-bottom: 8px; }
    .drop-zone-subtext { color: #999; font-size: 14px; }
    #preview { max-width: 100%; max-height: 400px; border-radius: 8px; display: none; margin: 20px auto; }
    .status { padding: 15px; border-radius: 8px; display: none; margin-top: 15px; font-size: 14px; }
    .status.info { background: #e3f2fd; color: #1565c0; display: block; }
    .status.success { background: #e8f5e9; color: #2e7d32; display: block; }
    .status.error { background: #ffebee; color: #c62828; display: block; }
    .result-box { display: none; margin-top: 30px; padding: 20px; background: #f8f9ff; border-radius: 12px; }
    .result-box.show { display: block; }
    .result-title { font-size: 18px; font-weight: 600; color: #333; margin-bottom: 15px; }
    .btn-group { display: flex; gap: 10px; flex-wrap: wrap; }
    .btn { background: #667eea; border: none; padding: 12px 24px; border-radius: 8px; cursor: pointer; font-size: 14px; font-weight: 600; color: white; transition: all 0.2s; text-decoration: none; display: inline-block; }
    .btn:hover { background: #764ba2; transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.2); }
    .btn-secondary { background: white; color: #667eea; border: 2px solid #667eea; }
    .btn-secondary:hover { background: #667eea; color: white; }
  </style>
</head>
<body>
  <div class="container">
    <div>
      <h1>🎨 Image to 3D scene</h1>
      <p class="subtitle">Convert any image to a 3D scene using Apple's SHARP</p>
    </div>
    <div class="note">⚠️ <strong>Note:</strong> SHARP generates basic RGB-only splats for speed. The output quality is limited compared to full training approaches, but processes quickly.</div>
    <div class="drop-zone" id="dropZone">
      <div class="drop-zone-text">📸 Drop an image here</div>
      <div class="drop-zone-subtext">or click to browse (JPG, PNG)</div>
      <input type="file" id="fileInput" accept="image/*" style="display: none;">
    </div>
    <img id="preview" alt="Preview">
    <div id="status" class="status"></div>
    <div id="resultBox" class="result-box">
      <div class="result-title">✨ 3D Model Ready!</div>
      <div class="btn-group">
        <a class="btn" id="downloadBtn" download>⬇️ Download .ply</a>
        <a class="btn btn-secondary" id="viewBtn" target="_blank">👁️ View in 3D</a>
      </div>
    </div>
  </div>
  <script>
    const dropZone = document.getElementById('dropZone');
    const fileInput = document.getElementById('fileInput');
    const preview = document.getElementById('preview');
    const status = document.getElementById('status');
    const resultBox = document.getElementById('resultBox');
    const downloadBtn = document.getElementById('downloadBtn');
    const viewBtn = document.getElementById('viewBtn');
    dropZone.addEventListener('click', () => fileInput.click());
    dropZone.addEventListener('dragover', (e) => { e.preventDefault(); dropZone.classList.add('dragover'); });
    dropZone.addEventListener('dragleave', () => { dropZone.classList.remove('dragover'); });
    dropZone.addEventListener('drop', (e) => { e.preventDefault(); dropZone.classList.remove('dragover'); if (e.dataTransfer.files.length) handleFile(e.dataTransfer.files[0]); });
    fileInput.addEventListener('change', (e) => { if (e.target.files.length) handleFile(e.target.files[0]); });
    function showStatus(message, type) {
      // Use textContent so server-supplied error text is never parsed as markup.
      status.textContent = message;
      status.className = 'status ' + type;
      if (type !== 'success') resultBox.classList.remove('show');
    }
    function handleFile(file) {
      if (!file.type.startsWith('image/')) { showStatus('❌ Please upload an image file', 'error'); return; }
      const reader = new FileReader();
      reader.onload = (e) => { preview.src = e.target.result; preview.style.display = 'block'; uploadImage(file); };
      reader.readAsDataURL(file);
    }
    function uploadImage(file) {
      const formData = new FormData(); formData.append('image', file);
      showStatus('⚙️ Processing image...', 'info');
      fetch('/process', { method: 'POST', body: formData })
        .then(response => response.json())
        .then(data => {
          if (data.success) {
            showStatus('✅ 3D view generated!', 'success');
            downloadBtn.href = '/static/' + data.filename; downloadBtn.download = data.filename; viewBtn.href = '/viewer?file=' + encodeURIComponent(data.filename); resultBox.classList.add('show');
          } else showStatus('❌ Error: ' + data.error, 'error');
        })
        .catch(error => { showStatus('❌ Error: ' + error.message, 'error'); });
    }
  </script>
</body>
</html>
"""
@app.route('/')
def index():
    """Serve the upload page built from HTML_TEMPLATE (rendered through Jinja)."""
    page = render_template_string(HTML_TEMPLATE)
    return page
@app.route('/viewer')
def viewer():
    """Render the 3D viewer page for a generated splat file.

    Query parameters:
        file: name of a .ply file in the static folder, as returned by /process.

    Returns:
        The rendered VIEWER_TEMPLATE, or a plain-text 400 response when
        ``file`` is missing or empty. Without this check the page would try to
        load the bare ``/static/`` URL and fail with an opaque loader error.
    """
    filename = request.args.get('file', '')
    if not filename:
        return "Missing 'file' query parameter", 400
    return render_template_string(VIEWER_TEMPLATE, filename=filename)
@app.route('/process', methods=['POST'])
def process_image():
    """Convert an uploaded image into a Gaussian-splat .ply in the static folder.

    Expects a multipart form with an ``image`` file field.

    Returns:
        JSON ``{'success': True, 'filename': <name>}`` on success. On failure,
        JSON ``{'success': False, 'error': <message>}`` with HTTP 400 for a
        missing upload, or HTTP 500 when decoding, inference or saving fails.
    """
    upload = request.files.get('image')
    if upload is None:
        return jsonify({'success': False, 'error': 'No image provided'}), 400
    if upload.filename == '':
        return jsonify({'success': False, 'error': 'No image selected'}), 400

    try:
        # load_model() only runs from __main__; when the app is started some
        # other way (e.g. `flask run`) the predictor would still be None.
        if predictor is None:
            load_model()

        with Image.open(io.BytesIO(upload.read())) as raw:
            # Apply the EXIF orientation tag so the scene matches the upright
            # preview the browser shows (phone photos are often stored rotated).
            image = ImageOps.exif_transpose(raw).convert('RGB')
        image_np = np.array(image)
        height, width = image_np.shape[:2]
        # Heuristic focal length in pixels: the longer image side.
        f_px = float(max(width, height))
        LOGGER.info(f"Processing: {width}x{height}, f_px={f_px}")

        with torch.no_grad():
            gaussians = predict_image(predictor, image_np, f_px, device)

        # A random name keeps results from overwriting each other. Hashing the
        # upload's filename reused one name for every "image.png", and
        # str hashes also change between interpreter runs.
        output_filename = f"splat_{uuid.uuid4().hex}.ply"
        # Write into Flask's static folder (resolved against the app root, not
        # the current working directory) so /static/<name> can serve the file.
        output_dir = Path(app.static_folder)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / output_filename
        save_ply(gaussians, f_px, (height, width), output_path)
        LOGGER.info(f"Saved to {output_path}")
        return jsonify({'success': True, 'filename': output_filename})
    except Exception as e:
        # Top-level request boundary: log the traceback and report the message.
        LOGGER.exception(f"Error: {e}")
        return jsonify({'success': False, 'error': str(e)}), 500
def _main():
    """Script entry point: load the SHARP model, then serve the app on port 5050."""
    LOGGER.info("Starting Flask app...")
    # Load before serving so request handlers have a model to use.
    load_model()
    LOGGER.info("Navigate to http://localhost:5050")
    # NOTE(review): host='0.0.0.0' listens on every network interface, not just
    # localhost; use '127.0.0.1' to keep the server private to this machine.
    app.run(host='0.0.0.0', port=5050, debug=False)


if __name__ == '__main__':
    _main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment