Source code for csm.client

import os
import time
import warnings
from urllib.request import urlretrieve
from dataclasses import dataclass
import requests
import base64
import PIL.Image
from io import BytesIO


class BackendClient:
    r"""A backend client class for raw GET/POST requests to the REST API.

    .. warning::
        This class should not be accessed directly. Instead, use
        :class:`CSMClient` to interface with the API.

    Parameters
    ----------
    api_key : str, optional
        API key for the CSM account you would like to use. If not provided,
        the environment variable :envvar:`CSM_API_KEY` is used instead.
    base_url : str
        Base url for the API. In general this should not be modified; it is
        included only for debugging purposes.

    Raises
    ------
    Exception
        If no ``api_key`` is given and :envvar:`CSM_API_KEY` is not set.
    """

    def __init__(
            self,
            api_key=None,
            base_url="https://api.csm.ai",
        ):
        if api_key is None:
            api_key = os.environ.get('CSM_API_KEY')
            if api_key is None:
                raise Exception(
                    "The argument `api_key` must be provided when env variable "
                    "CSM_API_KEY is not set."
                )
        self.api_key = api_key
        self.base_url = base_url

    @property
    def headers(self):
        """Headers sent with every request: JSON content type plus the API key."""
        return {
            'Content-Type': 'application/json',
            'x-api-key': self.api_key,
        }

    def _check_http_response(self, response: requests.Response):
        """Raise ``RuntimeError`` unless ``response`` has a 2xx status code."""
        if not (200 <= response.status_code < 300):
            raise RuntimeError(
                f"HTTP request failed with status code {response.status_code} "
                f"({response.reason})"
            )

    @staticmethod
    def _placement_parameters(scaled_bbox, pivot_point):
        """Build the ``pivot_point`` / ``scaled_bbox`` part of a request payload.

        Values are coerced to ``float`` so that every endpoint sends the same
        representation. ``scaled_bbox`` is only included when it has exactly
        three entries; any other length is silently omitted.
        """
        parameters = {"pivot_point": [float(s) for s in pivot_point]}
        if len(scaled_bbox) == 3:
            parameters["scaled_bbox"] = [float(s) for s in scaled_bbox]
        return parameters

    # image-to-3d API
    # -------------------------------------------

    def create_image_to_3d_session(
            self,
            image_url,
            *,
            generate_preview_mesh=False,
            auto_refine=False,
            ## refine args
            creativity="lowest",
            refine_speed="fast",
            polygon_count="high_poly",
            topology="tris",
            texture_resolution=2048,
            scaled_bbox=(),
            pivot_point=(0.0, 0.0, 0.0)
        ):
        """POST a new image-to-3d session and return the decoded JSON response.

        Parameters
        ----------
        image_url : str
            Sent to the API as ``image_url``.
        generate_preview_mesh, auto_refine : bool
            Forwarded unchanged in the request payload.
        creativity : {"lowest", "moderate", "highest"}
        refine_speed : {"slow", "fast"}
        polygon_count : {"low_poly", "high_poly"}
            Sent to the API under the key ``resolution``.
        topology : {"tris", "quads"}
        texture_resolution : int
            Must satisfy ``128 <= texture_resolution <= 2048``.
        scaled_bbox : sequence of 3 numbers, optional
            Only sent when it has exactly three entries.
        pivot_point : sequence of numbers
            Converted to a list of floats before sending.

        Raises
        ------
        AssertionError
            If an option has an unsupported value.
        RuntimeError
            If the HTTP status code is not 2xx.
        """
        # NOTE: validation uses ``assert`` (skipped under ``python -O``); kept
        # so that the exception type seen by existing callers is unchanged.
        assert creativity in ["lowest", "moderate", "highest"]
        assert refine_speed in ["slow", "fast"]
        assert polygon_count in ["low_poly", "high_poly"]
        assert topology in ["tris", "quads"]
        assert 128 <= texture_resolution <= 2048

        parameters = {
            "image_url": image_url,
            "generate_preview_mesh": generate_preview_mesh,
            "creativity": creativity,
            "refine_speed": refine_speed,
            "resolution": polygon_count,
            "auto_refine": auto_refine,
            "topology": topology,
            "texture_resolution": texture_resolution,
            "manual_segmentation": False,  # TODO: implement this option
        }
        parameters.update(self._placement_parameters(scaled_bbox, pivot_point))

        response = requests.post(
            url=f"{self.base_url}/image-to-3d-sessions",
            json=parameters,
            headers=self.headers,
        )
        self._check_http_response(response)  # expected=201

        return response.json()

    def get_image_to_3d_session_info(self, session_code):
        """GET the current state of an image-to-3d session as decoded JSON.

        Raises
        ------
        RuntimeError
            If the HTTP status code is not 2xx.
        """
        response = requests.get(
            url=f"{self.base_url}/image-to-3d-sessions/{session_code}",
            headers=self.headers,
        )
        self._check_http_response(response)  # expected=200

        return response.json()

    def get_3d_refine(self, session_code, scaled_bbox=(), pivot_point=(0.0, 0.0, 0.0)):
        """POST a refine request for a session and return the decoded JSON.

        Raises
        ------
        RuntimeError
            If the HTTP status code is not 2xx.
        """
        parameters = self._placement_parameters(scaled_bbox, pivot_point)

        response = requests.post(
            url=f"{self.base_url}/image-to-3d-sessions/get-3d/refine/{session_code}",
            json=parameters,
            headers=self.headers,
        )
        # Fail loudly on HTTP errors, like every other request in this class.
        self._check_http_response(response)

        return response.json()

    def get_3d_preview(self, session_code, spin_url=None, scaled_bbox=(), pivot_point=(0.0, 0.0, 0.0)):
        """POST a preview-mesh request for a session and return the decoded JSON.

        When ``spin_url`` is not given, the session info is fetched first and
        the ``image_url`` of its first spin is used.

        Raises
        ------
        RuntimeError
            If the HTTP status code of any request is not 2xx.
        """
        selected_spin_index = 0

        if spin_url is None:
            result = self.get_image_to_3d_session_info(session_code)
            spin_url = result['data']['spins'][selected_spin_index]["image_url"]

        parameters = {
            "selected_spin_index": selected_spin_index,
            "selected_spin": spin_url,
        }
        parameters.update(self._placement_parameters(scaled_bbox, pivot_point))

        response = requests.post(
            url=f"{self.base_url}/image-to-3d-sessions/get-3d/preview/{session_code}",
            json=parameters,
            headers=self.headers,
            timeout=100,
        )
        self._check_http_response(response)  # expected=200

        return response.json()

    # text-to-image API methods
    # -------------------------------------------

    def create_text_to_image_session(
            self,
            prompt,
            style_id="",
            guidance=6,
        ):
        """POST a new text-to-image session and return the decoded JSON.

        All three arguments are converted with ``str()`` before being sent.

        Raises
        ------
        RuntimeError
            If the HTTP status code is not 2xx.
        """
        parameters = {
            'prompt': str(prompt),
            'style_id': str(style_id),
            'guidance': str(guidance),
        }

        response = requests.post(
            url=f"{self.base_url}/tti-sessions",
            json=parameters,
            headers=self.headers,
        )
        self._check_http_response(response)

        return response.json()

    def get_text_to_image_session_info(self, session_code):
        """GET the current state of a text-to-image session as decoded JSON.

        Raises
        ------
        RuntimeError
            If the HTTP status code is not 2xx.
        """
        response = requests.get(
            url=f"{self.base_url}/tti-sessions/{session_code}",
            headers=self.headers,
        )
        self._check_http_response(response)

        return response.json()
@dataclass
class ImageTo3DResult:
    r"""Result record returned by image-to-3d generation.

    Parameters
    ----------
    session_code : str
        Code identifying the image-to-3d session that produced the mesh.
    mesh_path : str
        Path on the local filesystem where the mesh file was saved.
    """

    session_code: str
    mesh_path: str
@dataclass
class TextTo3DResult:
    r"""Result record returned by text-to-3d generation.

    Parameters
    ----------
    session_code : str
        Code identifying the image-to-3d session that produced the mesh.
    mesh_path : str
        Path on the local filesystem where the mesh file was saved.
    image_path : str
        Path on the local filesystem of the intermediate image generated
        from the text prompt.
    """

    session_code: str
    mesh_path: str
    image_path: str
class CSMClient:
    r"""Core client utility for accessing the CSM API.

    Parameters
    ----------
    api_key : str, optional
        API key for the CSM account you would like to use. If not provided,
        the environment variable :envvar:`CSM_API_KEY` is used instead.
    base_url : str
        Base url for the API. In general this should not be modified; it is
        included only for debugging purposes.
    """

    # mesh_format -> (key of the download URL in the session info, local file name)
    _MESH_EXPORTS = {
        'obj': ('preview_mesh_url_zip', 'mesh.zip'),
        'glb': ('preview_mesh_url_glb', 'mesh.glb'),
        'usdz': ('preview_mesh_url_usdz', 'mesh.usdz'),
    }

    def __init__(
            self,
            api_key=None,
            base_url="https://api.csm.ai",
        ):
        self.backend = BackendClient(api_key=api_key, base_url=base_url)

    def _handle_image_input(self, image):
        """Convert the user-supplied image into the string sent to the API.

        A :class:`PIL.Image.Image` or a path to an existing local file is
        encoded as a base64 PNG data URI. Any other string is returned
        unchanged and treated as a URL.

        Raises
        ------
        ValueError
            If ``image`` is neither a string nor a :class:`PIL.Image.Image`.
        """
        if isinstance(image, PIL.Image.Image):
            return pil_image_to_x64(image)

        if not isinstance(image, str):
            raise ValueError("Encountered unexpected type for the input image.")

        if os.path.isfile(image):  # local file path
            # Close the file handle as soon as the image has been encoded.
            with PIL.Image.open(image) as pil_image:
                return pil_image_to_x64(pil_image)

        # URL for a web file
        return image  # TODO: verify that this is a valid URL

    @staticmethod
    def _wait_for_status(
            fetch_info,
            *,
            done,
            pending,
            label,
            timeout,
            failed=None,
            failed_message=None,
            poll_interval=2
        ):
        """Poll a session until it reaches the ``done`` status.

        Parameters
        ----------
        fetch_info : callable
            Zero-argument callable returning the session info dict; the
            status is read from ``result['data']['status']``.
        done, pending : str
            Status values meaning "finished" and "still running".
        label : str
            Lower-case step name used in error messages.
        timeout : float
            Maximum number of seconds to keep polling while ``pending``.
        failed : str, optional
            Status value meaning the step failed, if the API reports one.
        failed_message : str, optional
            Error message raised when ``failed`` is observed.
        poll_interval : float
            Seconds to sleep before each poll.

        Returns
        -------
        tuple
            ``(result, run_time)``: the final session info and the seconds
            elapsed when it was received.

        Raises
        ------
        RuntimeError
            On a failed status, an unexpected status, or a timeout.
        """
        title = label[:1].upper() + label[1:]
        start_time = time.time()
        while True:
            time.sleep(poll_interval)
            result = fetch_info()
            status = result['data']['status']
            # Measure right after the poll so the time reported on success
            # covers the final poll instead of lagging one iteration behind.
            run_time = time.time() - start_time
            if status == done:
                return result, run_time
            if failed is not None and status == failed:
                raise RuntimeError(failed_message)
            if status != pending:
                raise RuntimeError(f"Unexpected error during {label} (status='{status}')")
            if run_time >= timeout:
                raise RuntimeError(f"{title} timed out")

    def image_to_3d(
            self,
            image,
            *,
            generate_spin_video=False,
            mesh_format='obj',
            output='./',
            timeout=200,
            verbose=True,
            scaled_bbox=(),
            pivot_point=(0.0, 0.0, 0.0),
            refine_speed="fast"
        ):
        r"""Generate a 3D mesh from an image.

        The input image can be provided as a URL, a local path, or a
        :class:`PIL.Image.Image`.

        Parameters
        ----------
        image : str or PIL.Image.Image
            The input image. May be provided as a url, a local path, or a
            :class:`PIL.Image.Image` instance.
        generate_spin_video : bool
            Deprecated. When true, a spin video is generated first and the
            preview mesh export is launched from its first spin.
        mesh_format : {'obj', 'glb', 'usdz'}
            Format of the downloaded mesh (case-insensitive). ``'obj'`` is
            saved as ``mesh.zip``.
        output : str
            Directory in which the mesh file is saved; created if missing.
        timeout : float
            Maximum number of seconds to wait for each generation stage.
        verbose : bool
            Print progress messages.
        scaled_bbox, pivot_point, refine_speed
            Forwarded to the backend when the session is created.

        Returns
        -------
        ImageTo3DResult
            Result object. Contains the local path of the generated mesh file.

        Raises
        ------
        ValueError
            If ``mesh_format`` or the type of ``image`` is not supported.
        RuntimeError
            If session creation or a generation stage fails or times out.
        """
        if generate_spin_video:
            warnings.warn(
                "The option `generate_spin_video=True` is deprecated and will be removed "
                "in a future release", DeprecationWarning, stacklevel=2)

        mesh_format = mesh_format.lower()
        if mesh_format not in self._MESH_EXPORTS:
            raise ValueError(
                f"Unexpected mesh_format value ('{mesh_format}'). Please choose "
                f"from options ['obj', 'glb', 'usdz']."
            )

        image_url = self._handle_image_input(image)

        os.makedirs(output, exist_ok=True)

        # initialize session
        result = self.backend.create_image_to_3d_session(
            image_url,
            generate_preview_mesh=not generate_spin_video,
            auto_refine=False,
            scaled_bbox=scaled_bbox,
            pivot_point=pivot_point,
            refine_speed=refine_speed
        )

        status = result['data']['status']
        if generate_spin_video:
            accepted = ["spin_generate_processing", "spin_generate_done"]
        else:
            accepted = ["training_preview", "preview_done"]
        if status not in accepted:
            raise RuntimeError(f"Image-to-3d session creation failed (status='{status}')")

        session_code = result['data']['session_code']
        step_label = "spin generation" if generate_spin_video else "mesh generation"

        if verbose:
            print(f'[INFO] Image-to-3d session created ({session_code})')

        def fetch_info():
            return self.backend.get_image_to_3d_session_info(session_code)

        if generate_spin_video:
            if verbose:
                print(f'[INFO] Running preview {step_label}...')

            # wait for preview spin generation to complete (20-30s)
            result, run_time = self._wait_for_status(
                fetch_info,
                done='spin_generate_done',
                pending='spin_generate_processing',
                failed='spin_generate_failed',
                label=f"preview {step_label}",
                failed_message=f"Preview {step_label} failed",
                timeout=timeout,
            )

            if verbose:
                print(f'[INFO] Preview {step_label} completed in {run_time:.1f}s')

            spin_url = result['data']['spins'][0]["image_url"]

            # launch preview mesh export
            self.backend.get_3d_preview(
                session_code,
                spin_url=spin_url,
                scaled_bbox=scaled_bbox,
                pivot_point=pivot_point
            )

            step_label = "mesh export"

        if verbose:
            print(f'[INFO] Running preview {step_label}...')

        # wait for preview mesh export to complete (20-30s)
        result, run_time = self._wait_for_status(
            fetch_info,
            done='preview_done',
            pending='training_preview',
            failed='preview_failed',
            label=f"preview {step_label}",
            failed_message=f"Preview {step_label} failed.",
            timeout=timeout,
        )

        if verbose:
            print(f'[INFO] Preview {step_label} completed in {run_time:.1f}s')

        # download mesh file based on the requested format
        url_key, mesh_file = self._MESH_EXPORTS[mesh_format]
        mesh_url = result['data'][url_key]

        mesh_path = os.path.join(output, mesh_file)  # TODO: os.path.abspath ?
        urlretrieve(mesh_url, mesh_path)

        return ImageTo3DResult(session_code=session_code, mesh_path=mesh_path)

    def text_to_3d(
            self,
            prompt,
            *,
            style_id="",
            guidance=6,
            generate_spin_video=False,
            mesh_format='obj',
            output='./',
            timeout=200,
            verbose=True,
            scaled_bbox=(),
            pivot_point=(0.0, 0.0, 0.0),
            refine_speed="fast"
        ):
        r"""Generate a 3D mesh from a text prompt.

        An image is first generated from the prompt and saved as
        ``image.png`` in ``output``; its URL is then passed to
        :meth:`image_to_3d`.

        Parameters
        ----------
        prompt : str
            The input text prompt.
        style_id, guidance
            Forwarded to the text-to-image session.
        generate_spin_video, mesh_format, output, timeout, verbose, \
        scaled_bbox, pivot_point, refine_speed
            Forwarded to :meth:`image_to_3d`. ``timeout`` also bounds the
            text-to-image stage, and ``output`` also receives the image.

        Returns
        -------
        TextTo3DResult
            Result object. Contains the local path of the generated mesh file,
            as well as the image that was generated as part of the pipeline.

        Raises
        ------
        RuntimeError
            If session creation or a generation stage fails or times out.
        """
        os.makedirs(output, exist_ok=True)

        # initialize text-to-image session
        result = self.backend.create_text_to_image_session(
            prompt,
            style_id=style_id,
            guidance=guidance,
        )

        status = result['data']['status']
        if status not in ("processing", "completed"):
            raise RuntimeError(f"Text-to-image session creation failed (status='{status}')")

        session_code = result['data']['session_code']

        if verbose:
            print(f'[INFO] Text-to-image session created ({session_code})')
            print(f'[INFO] Running text-to-image generation...')

        # wait for image generation to complete
        result, run_time = self._wait_for_status(
            lambda: self.backend.get_text_to_image_session_info(session_code),
            done='completed',
            pending='processing',
            label="text-to-image generation",
            timeout=timeout,
        )

        if verbose:
            print(f'[INFO] Text-to-image generation completed in {run_time:.1f}s')

        # access the image URL
        image_url = result['data']['image_url']

        # download image
        image_path = os.path.join(output, 'image.png')
        urlretrieve(image_url, image_path)

        # launch image-to-3d
        i23 = self.image_to_3d(
            image_url,
            generate_spin_video=generate_spin_video,
            mesh_format=mesh_format,
            output=output,
            timeout=timeout,
            verbose=verbose,
            scaled_bbox=scaled_bbox,
            pivot_point=pivot_point,
            refine_speed=refine_speed
        )

        return TextTo3DResult(session_code=i23.session_code, mesh_path=i23.mesh_path, image_path=image_path)
def pil_image_to_x64(image: PIL.Image.Image) -> str:
    """Encode a PIL image as a base64 PNG data URI string."""
    png_stream = BytesIO()
    image.save(png_stream, "PNG")
    encoded = base64.b64encode(png_stream.getvalue()).decode("utf-8")
    return 'data:image/png;base64,' + encoded