@notnotrishi
Created January 6, 2026 01:29
Flask app to convert a single image into a 3D scene (Gaussian splat) using Apple's SHARP model and render it with the Spark viewer
"""
Flask application for converting images to 3D using Apple's SHARP model.
Installation:
1. Clone ml-sharp repository:
git clone https://github.com/apple/ml-sharp
cd ml-sharp
pip install -r requirements.txt
2. Install Flask and Flask-CORS:
pip install flask flask-cors
3. Place this script in the ml-sharp directory
4. Run: python main.py
5. Navigate to http://localhost:5050
"""
from flask import Flask, render_template_string, request, jsonify, send_file, send_from_directory
from flask_cors import CORS
from pathlib import Path
import torch
import numpy as np
from PIL import Image
import io
import logging
app = Flask(__name__)
CORS(app)
@app.after_request
def add_headers(response):
    # Cross-origin isolation headers (COOP/COEP), which some in-browser
    # splat viewers need for SharedArrayBuffer/worker support.
    response.headers['Cross-Origin-Opener-Policy'] = 'same-origin'
    response.headers['Cross-Origin-Embedder-Policy'] = 'require-corp'
    return response
from sharp.models import create_predictor
from sharp.models.params import PredictorParams
from sharp.utils.gaussians import save_ply
from sharp.cli.predict import predict_image
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
predictor = None
device = None
def load_model():
    global predictor, device
    # Prefer CUDA, then Apple Silicon (MPS), falling back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    LOGGER.info(f"Using device: {device}")

    LOGGER.info("Loading SHARP model...")
    DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
    state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
    predictor = create_predictor(PredictorParams())
    predictor.load_state_dict(state_dict)
    predictor.eval()
    predictor.to(device)
    LOGGER.info("Model loaded")
VIEWER_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<title>3D Viewer</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
body { margin:0; overflow:hidden; font-family: monospace; background:#000; }
#info { position:absolute; top:15px; left:15px; color:white; background:rgba(0,0,0,0.7); padding:15px; border-radius:8px; z-index:1000; font-size:14px; max-width:300px }
#loading { position:absolute; top:50%; left:50%; transform:translate(-50%,-50%); color:white; font-size:18px; text-align:center }
.spinner{ border:4px solid rgba(255,255,255,0.3); border-top:4px solid white; border-radius:50%; width:50px; height:50px; animation:spin 1s linear infinite; margin:0 auto 15px }
@keyframes spin{0%{transform:rotate(0deg)}100%{transform:rotate(360deg)}}
</style>
</head>
<body>
<div id="info">
<div><strong>Spark Viewer</strong></div>
<div id="status">Loading...</div>
<div style="margin-top:10px; font-size:12px; color:#aaa;">Mouse: Drag to orbit, scroll to zoom<br>File: {{ filename }}</div>
</div>
<div id="loading"><div class="spinner"></div><div>Loading 3D scene...</div></div>
<script type="importmap">
{
    "imports": {
        "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.178.0/three.module.js",
        "@sparkjsdev/spark": "https://sparkjs.dev/releases/spark/0.1.10/spark.module.js"
    }
}
</script>
<script type="module">
import * as THREE from 'three';
import { SplatMesh, SparkRenderer } from '@sparkjsdev/spark';

const statusEl = document.getElementById('status');
const loadingEl = document.getElementById('loading');
statusEl.textContent = 'Initializing viewer...';

// Create scene
const scene = new THREE.Scene();

// Create camera
const camera = new THREE.PerspectiveCamera(
    60,
    window.innerWidth / window.innerHeight,
    0.1,
    1000
);
camera.position.set(0, -2, -5);
camera.lookAt(0, 0, 0);

// Create renderer
const renderer = new THREE.WebGLRenderer({ antialias: false });
renderer.setSize(window.innerWidth, window.innerHeight);
document.body.appendChild(renderer.domElement);

// Create SparkRenderer
const spark = new SparkRenderer({ renderer });
scene.add(spark);

// Load splat
const splatURL = '{{ url_for('static', filename=filename) }}';
const splatMesh = new SplatMesh({ url: splatURL });
splatMesh.quaternion.set(1, 0, 0, 0);
splatMesh.position.set(0, 0, 0);
scene.add(splatMesh);

// Simple orbit controls (mouse drag)
let isDragging = false;
let previousMousePosition = { x: 0, y: 0 };
const rotationSpeed = 0.005;

renderer.domElement.addEventListener('mousedown', (e) => {
    isDragging = true;
    previousMousePosition = { x: e.clientX, y: e.clientY };
});
renderer.domElement.addEventListener('mousemove', (e) => {
    if (isDragging) {
        const deltaX = e.clientX - previousMousePosition.x;
        const deltaY = e.clientY - previousMousePosition.y;
        camera.position.applyAxisAngle(new THREE.Vector3(0, 1, 0), -deltaX * rotationSpeed);
        const right = new THREE.Vector3(1, 0, 0).applyQuaternion(camera.quaternion);
        camera.position.applyAxisAngle(right, -deltaY * rotationSpeed);
        camera.lookAt(0, 0, 0);
        previousMousePosition = { x: e.clientX, y: e.clientY };
    }
});
renderer.domElement.addEventListener('mouseup', () => {
    isDragging = false;
});
renderer.domElement.addEventListener('mouseleave', () => {
    isDragging = false;
});

// Zoom with scroll
renderer.domElement.addEventListener('wheel', (e) => {
    e.preventDefault();
    const zoomSpeed = 0.1;
    const direction = camera.position.clone().normalize();
    camera.position.addScaledVector(direction, e.deltaY * zoomSpeed * 0.01);
});

// Handle window resize
window.addEventListener('resize', () => {
    camera.aspect = window.innerWidth / window.innerHeight;
    camera.updateProjectionMatrix();
    renderer.setSize(window.innerWidth, window.innerHeight);
});

// Animation loop
function animate() {
    requestAnimationFrame(animate);
    renderer.render(scene, camera);
}

// Give the splat a moment to load, then hide the overlay and start rendering
// (no explicit load event is hooked here).
setTimeout(() => {
    statusEl.textContent = 'Loaded successfully!';
    loadingEl.style.display = 'none';
    document.getElementById('info').style.opacity = '0.7';
    animate();
}, 1000);
</script>
</body>
</html>
"""
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<title>Image to 3D</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.container {
background: white;
border-radius: 20px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
max-width: 900px;
width: 100%;
padding: 40px;
}
h1 { font-size: 32px; color: #333; margin-bottom: 10px; }
.subtitle { color: #666; font-size: 14px; margin-bottom: 30px; }
.note { background: #fff3cd; border-left: 4px solid #ffc107; padding: 12px; margin-bottom: 20px; font-size: 13px; color: #856404; }
.drop-zone { border: 3px dashed #667eea; border-radius: 12px; padding: 60px 20px; text-align: center; cursor: pointer; transition: all 0.3s; background: #f8f9ff; margin-bottom: 20px; }
.drop-zone:hover, .drop-zone.dragover { border-color: #764ba2; background: #f0f0ff; transform: scale(1.02); }
.drop-zone-text { color: #667eea; font-size: 18px; font-weight: 600; margin-bottom: 8px; }
.drop-zone-subtext { color: #999; font-size: 14px; }
#preview { max-width: 100%; max-height: 400px; border-radius: 8px; display: none; margin: 20px auto; }
.status { padding: 15px; border-radius: 8px; display: none; margin-top: 15px; font-size: 14px; }
.status.info { background: #e3f2fd; color: #1565c0; display: block; }
.status.success { background: #e8f5e9; color: #2e7d32; display: block; }
.status.error { background: #ffebee; color: #c62828; display: block; }
.result-box { display: none; margin-top: 30px; padding: 20px; background: #f8f9ff; border-radius: 12px; }
.result-box.show { display: block; }
.result-title { font-size: 18px; font-weight: 600; color: #333; margin-bottom: 15px; }
.btn-group { display: flex; gap: 10px; flex-wrap: wrap; }
.btn { background: #667eea; border: none; padding: 12px 24px; border-radius: 8px; cursor: pointer; font-size: 14px; font-weight: 600; color: white; transition: all 0.2s; text-decoration: none; display: inline-block; }
.btn:hover { background: #764ba2; transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.2); }
.btn-secondary { background: white; color: #667eea; border: 2px solid #667eea; }
.btn-secondary:hover { background: #667eea; color: white; }
</style>
</head>
<body>
<div class="container">
<div>
<h1>🎨 Image to 3D scene</h1>
<p class="subtitle">Convert any image to a 3D scene using Apple's SHARP</p>
</div>
<div class="note">⚠️ <strong>Note:</strong> SHARP generates basic RGB-only splats for speed. The output quality is limited compared to full training approaches, but processes quickly.</div>
<div class="drop-zone" id="dropZone">
<div class="drop-zone-text">📸 Drop an image here</div>
<div class="drop-zone-subtext">or click to browse (JPG, PNG)</div>
<input type="file" id="fileInput" accept="image/*" style="display: none;">
</div>
<img id="preview" alt="Preview">
<div id="status" class="status"></div>
<div id="resultBox" class="result-box">
<div class="result-title">✨ 3D Model Ready!</div>
<div class="btn-group">
<a class="btn" id="downloadBtn" download>⬇️ Download .ply</a>
<a class="btn btn-secondary" id="viewBtn" target="_blank">👁️ View in 3D</a>
</div>
</div>
</div>
<script>
const dropZone = document.getElementById('dropZone');
const fileInput = document.getElementById('fileInput');
const preview = document.getElementById('preview');
const status = document.getElementById('status');
const resultBox = document.getElementById('resultBox');
const downloadBtn = document.getElementById('downloadBtn');
const viewBtn = document.getElementById('viewBtn');
dropZone.addEventListener('click', () => fileInput.click());
dropZone.addEventListener('dragover', (e) => { e.preventDefault(); dropZone.classList.add('dragover'); });
dropZone.addEventListener('dragleave', () => { dropZone.classList.remove('dragover'); });
dropZone.addEventListener('drop', (e) => { e.preventDefault(); dropZone.classList.remove('dragover'); if (e.dataTransfer.files.length) handleFile(e.dataTransfer.files[0]); });
fileInput.addEventListener('change', (e) => { if (e.target.files.length) handleFile(e.target.files[0]); });
function showStatus(message, type) { status.innerHTML = message; status.className = 'status ' + type; if (type !== 'success') resultBox.classList.remove('show'); }
function handleFile(file) {
    if (!file.type.startsWith('image/')) { showStatus('❌ Please upload an image file', 'error'); return; }
    const reader = new FileReader();
    reader.onload = (e) => { preview.src = e.target.result; preview.style.display = 'block'; uploadImage(file); };
    reader.readAsDataURL(file);
}
function uploadImage(file) {
    const formData = new FormData();
    formData.append('image', file);
    showStatus('⚙️ Processing image...', 'info');
    fetch('/process', { method: 'POST', body: formData })
        .then(response => response.json())
        .then(data => {
            if (data.success) {
                showStatus('✅ 3D view generated!', 'success');
                downloadBtn.href = '/static/' + data.filename;
                downloadBtn.download = data.filename;
                viewBtn.href = '/viewer?file=' + encodeURIComponent(data.filename);
                resultBox.classList.add('show');
            } else {
                showStatus('❌ Error: ' + data.error, 'error');
            }
        })
        .catch(error => { showStatus('❌ Error: ' + error.message, 'error'); });
}
</script>
</body>
</html>
"""
@app.route('/')
def index():
    return render_template_string(HTML_TEMPLATE)


@app.route('/viewer')
def viewer():
    filename = request.args.get('file', '')
    return render_template_string(VIEWER_TEMPLATE, filename=filename)
@app.route('/process', methods=['POST'])
def process_image():
    try:
        if 'image' not in request.files:
            return jsonify({'success': False, 'error': 'No image provided'})
        file = request.files['image']
        if file.filename == '':
            return jsonify({'success': False, 'error': 'No image selected'})

        image_data = file.read()
        image = Image.open(io.BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_np = np.array(image)
        height, width = image_np.shape[:2]
        # Heuristic focal length in pixels when no camera calibration is available.
        f_px = float(max(width, height))
        LOGGER.info(f"Processing: {width}x{height}, f_px={f_px}")

        with torch.no_grad():
            gaussians = predict_image(predictor, image_np, f_px, device)

        # Name the output .ply from a hash of the uploaded filename.
        output_filename = f"splat_{hash(file.filename) & 0x7FFFFFFF}.ply"
        output_path = Path('static') / output_filename
        output_path.parent.mkdir(exist_ok=True)
        save_ply(gaussians, f_px, (height, width), output_path)
        LOGGER.info(f"Saved to {output_path}")
        return jsonify({'success': True, 'filename': output_filename})
    except Exception as e:
        LOGGER.error(f"Error: {str(e)}", exc_info=True)
        return jsonify({'success': False, 'error': str(e)})


if __name__ == '__main__':
    LOGGER.info("Starting Flask app...")
    load_model()
    LOGGER.info("Navigate to http://localhost:5050")
    app.run(debug=False, host='0.0.0.0', port=5050)